Changes from all commits · 142 commits
5b3092a
Changed VERSION to 2.9.0.dev0
ptrendx Sep 19, 2025
57b4d7b
[JAX] Remove import jax.extend.ffi (#2193)
phu0ngng Sep 22, 2025
5e4e0b2
[PyTorch] Add sink attention support from cuDNN (#2148)
cyanguwa Sep 22, 2025
2db20a6
[QA] Add pytest xml report for all tests in qa folder that use pytest…
shengfangd Sep 23, 2025
a92a0ad
[JAX] Local-Amax for Current-Scaling (#2183)
mingxu1067 Sep 23, 2025
3f875fb
[JAX] Restore Shardy Rule with CompoundFactor (#2167)
phu0ngng Sep 23, 2025
afd15a1
[JAX] Update JAX version requirement in pyproject.toml (#2197)
phu0ngng Sep 24, 2025
9e72796
[PyTorch] Unpin version of onnxscript and onnxruntime (#2202)
pggPL Sep 26, 2025
4d14578
[JAX] Fix XML filename in the L0_jax_uniitest (#2205)
phu0ngng Sep 27, 2025
d75bf43
[JAX] CollectiveGemm (#2166)
phu0ngng Sep 27, 2025
a91e458
[JAX] Add xml export for `test_multiprocessing_encoder` and `test_cge…
phu0ngng Sep 29, 2025
dfeef1a
[JAX] Address tolerance check for current scaling dact dbias (#2211)
jberchtold-nvidia Sep 29, 2025
3f5b475
[Core][PyTorch] NVFP4 recipe (#2177)
ksivaman Sep 29, 2025
2354fb8
Fix the segfault in the nvfp4 quantization (#2214)
ptrendx Sep 30, 2025
25252e9
[PyTorch] Add FP8 attention with current scaling (#2012)
cyanguwa Sep 30, 2025
7fa0f55
[Pytorch] Support for Swiglu Activation used in GPT OSS (#2161)
vthumbe1503 Sep 30, 2025
ce18bee
[JAX] Load modules during initialize for Norm and Act primitives (#2219)
jberchtold-nvidia Sep 30, 2025
7022d50
[PyTorch] Quantizer as API (#2039)
negvet Oct 1, 2025
ac4e0fd
[JAX] Rework amax reduction over TPSP (#2218)
phu0ngng Oct 1, 2025
b0d562d
[JAX] Fix `rng_state` shape in fused attention (#2217)
phu0ngng Oct 1, 2025
ac886c3
[PyTorch] Fix QuantizedTensorBase -> QuantizedTensorStorage (#2226)
negvet Oct 1, 2025
f0a9404
Fix hang during debug build (#2221)
ksivaman Oct 1, 2025
90449f7
Convert `NVFP4BlockScaling` to dataclass (#2227)
ksivaman Oct 1, 2025
aee5a82
Fix the cuBLAS workspace alignment (#2223)
ptrendx Oct 1, 2025
c100318
[PyTorch] Set usages for linear op quantizers before forward (#2222)
timmoon10 Oct 2, 2025
f936c2a
[JAX] Fix code block in fp8_autocast docstring (#2228)
jberchtold-nvidia Oct 2, 2025
be7f43f
[JAX] Fix shard map issue when `get_all_mesh_axes()` is used (#2229)
jberchtold-nvidia Oct 2, 2025
e30c36a
[PyTorch] fix int32 overflow in permute kernels (#2196)
hxbai Oct 2, 2025
b840898
[JAX] Clamped Swiglu Integration (#2194)
vthumbe1503 Oct 3, 2025
dfe5b7d
[Common][Pytorch] Add support for the FP8 Block Scaling (ie. Deepseek…
janekb04 Oct 3, 2025
5be8125
Fix bug where CUTLASS kernel was not being compiled for SM90a (#2235)
timmoon10 Oct 4, 2025
08779fd
Fix FP8 current scaling attention logic (#2234)
ksivaman Oct 4, 2025
7e45be7
Added the NVFP4 section to the low precision training tutorial (#2237)
ptrendx Oct 5, 2025
0db0f4d
[JAX] Fix for GEMM + fuse bias + AllReduce (#2230)
phu0ngng Oct 6, 2025
56e2fed
[Build] fix: TE installation failed to find uv-installed cuDNN librar…
KivenChen Oct 6, 2025
9f3e79b
[PyTorch] Fix tests for 🤗 integration (#2239)
ksivaman Oct 6, 2025
127b6d3
[JAX] Activation/Normalization to output amax for later quantization …
phu0ngng Oct 7, 2025
76bced5
`NVFP4BlockScaling` recipe docs (#2241)
ksivaman Oct 7, 2025
ac5e868
Skip fp8 tests on unsupported devices (#2243)
vcherepanov-nv Oct 7, 2025
66f9b3c
[PyTorch] Unblock fused bgrad quantization path for nvfp4 (#2246)
ksivaman Oct 8, 2025
af2a0c1
[JAX] Async issuing D2H memcpy for grouped_gemm group_sizes array (#2…
huanghua1994 Oct 8, 2025
e37e33e
Disallow pure E5M2 recipe for `Float8BlockScaling` (#2251)
ksivaman Oct 9, 2025
9bf4175
[PyTorch] Deprecate old `float8_tensor.py` (#2250)
ksivaman Oct 9, 2025
e99be1b
Update minimum python version to 3.10 and add checks in CI (#2247)
ksivaman Oct 9, 2025
8a7ab3d
[JAX] NVFP4 support in TE/JAX (#2254)
jberchtold-nvidia Oct 9, 2025
dd9433e
Don't pickle an empty dict in LayerNorm and pt base modules (#2253)
pstjohn Oct 9, 2025
7ad130e
Offloading support for multiple attention layouts (#2024)
sanandaraj5597 Oct 13, 2025
8eec200
Disable torch autocast context in rope forward pass (#2240)
pstjohn Oct 13, 2025
8c364b4
[Common][JAX] Improve error message for cublas fp8 gemm with incorrec…
jberchtold-nvidia Oct 13, 2025
76e1af3
[JAX] Add assertion message to amax -> scale computation (#2263)
jberchtold-nvidia Oct 13, 2025
a3b749b
FSDP grad fusion support (#2191)
sanandaraj5597 Oct 13, 2025
5ec0f33
[JAX] Fix test path for fp8 grouped gemm ag (#2262)
KshitijLakhani Oct 14, 2025
dfacd9f
[PyTorch] Use Quantization API for reference NVFP4 recipe (#2259)
negvet Oct 14, 2025
ca6fedc
[JAX] Add BRCM support for THD (#2242)
KshitijLakhani Oct 14, 2025
85a9199
Generalize quantization APIs for FP8/FP4/.. recipes (#2256)
ksivaman Oct 14, 2025
fd2f589
[PyTorch] Bump minimum cuDNN version for fused attention with FP8 cur…
timmoon10 Oct 14, 2025
4c572f0
[PyTorch Debug] Fix issue with start_end_list logging feature (#2252)
paul-gibbons Oct 15, 2025
88564d5
README - latest news update (#2273)
sbhavani Oct 15, 2025
452c737
Added support for DistOpt with offloading with MoE's (#2264)
sanandaraj5597 Oct 16, 2025
81c363b
[PyTorch] Add record_stream and untyped_storage func op in QuantizedT…
xiaoxi-wangfj Oct 16, 2025
5624dbb
Changed VERSION to 2.10.0.dev0
ptrendx Oct 16, 2025
9dd6192
[JAX] Fix imports in test for deprecated jax.experimental.pjit (#2274)
KshitijLakhani Oct 17, 2025
05dc1e6
NVFP4 Move RHT BLAS to GPU (#2275)
kevin-tong-augment Oct 17, 2025
bd38004
fall back after failing ldconfig-based lib loading for cuDNN (#2277)
getim Oct 17, 2025
a7a69ca
Bump up FA to 2.8.3 (#2282)
Owen1B Oct 17, 2025
c593bce
Fix test of FSDP2 by correcting init logic and applying autocast (#2105)
ntenenz Oct 17, 2025
ee384ab
Make `CanonicalizeGemmInput()` support non-TN layout FP8 GEMM on Blac…
denera Oct 17, 2025
fd234d8
Wheels for cuda 13 (#2278)
ksivaman Oct 18, 2025
dd7ab71
Fix error with triton 3.5 (#2286)
fzyzcjy Oct 20, 2025
bd55e7b
[PyTorch] Fix CI failures due to deterministic attention backend (#2288)
ksivaman Oct 20, 2025
b4a1d4d
[PyTorch][MOE] Support NVFP4 Grouped Linear (#2215)
zhongbozhu Oct 21, 2025
e90582f
[Common] Removed activations from NVFP4 quantize C++ unit tests (#2289)
Oleg-Goncharov Oct 21, 2025
ce2f9fa
[JAX] HuggingFace login in JAX examples if token is available (#2290)
jberchtold-nvidia Oct 21, 2025
2712bb9
Add post-processing API for FP8 primary weights to support CUDA Graph…
kunlunl Oct 21, 2025
ce2e8bd
[PyTorch] Decouple python quantization classes and refactor custom qu…
negvet Oct 22, 2025
818b30c
[JAX] NVFP4 recipe with option to enable/disable SR, RHT, and 2D quan…
jberchtold-nvidia Oct 22, 2025
2ac3c16
[JAX] Defer TE/JAX cublas shape check on fp8 gemms until lowering (#2…
jberchtold-nvidia Oct 22, 2025
66acb8e
Include TE core headers in final build (#2291)
ksivaman Oct 23, 2025
eb34783
Overhaul the compilation for the arch-specific features (#2279)
ptrendx Oct 23, 2025
e2f2a0b
[JAX] Make SR rng state always 2D (num_devices, 4) to fix partitionin…
jberchtold-nvidia Oct 23, 2025
021e1e6
[PyTorch Debug] Fix issue with microbatching + debug value caching (#…
pggPL Oct 23, 2025
6273ced
[PyTorch] Support delay_wgrad_compute cudagraph (#1948)
buptzyb Oct 24, 2025
060811c
[Common] Fix checks in quantize_transpose_vector_blockwise_fp4 (#2299)
jberchtold-nvidia Oct 24, 2025
87cb26c
[PyTorch] Add max_logit support for MuonClip (#2195)
cyanguwa Oct 25, 2025
d2945c6
[PyTorch] Use dummy wgrad in GroupedLinear (#2305)
Autumn1998 Oct 27, 2025
d7c9777
Remove `nvidia-mathdx` dependency (#2295)
ksivaman Oct 27, 2025
a019c80
Submodule checkout during setup (#2293)
ksivaman Oct 27, 2025
4cf2f12
Change the pyTorch installation to CUDA 13 in Build All GitHub action…
ptrendx Oct 27, 2025
a8e4346
[JAX] Use TE quantization when TE fused norm is disable (#2303)
phu0ngng Oct 28, 2025
c6cbcc8
[Pytorch] Integrate GPT OSS Swiglu in TransformerLayer (#2312)
vthumbe1503 Oct 29, 2025
f0295f9
CMake to respect MAX_JOBS or NVTE_MAX_JOBS (#2319)
phu0ngng Oct 30, 2025
5e8a9a9
[JAX] Fix: Skip determinism tests for bprop for all sm >=100 (#2315)
KshitijLakhani Oct 30, 2025
490a5f4
[PyTorch] Fix attention backend and tests for `sm120` (#2320)
ksivaman Oct 30, 2025
0e80c84
[Common] Split cast/gated kernels by scaling mode (#2248)
Oleg-Goncharov Oct 30, 2025
26370b1
[PyT] Bump the min version expected to supported FP8 current scaling …
KshitijLakhani Oct 30, 2025
1269b2e
[JAX] Ensure JAX reference impl uses an accurate backend in our tests…
jberchtold-nvidia Oct 30, 2025
006670d
[JAX] Fix mesh resource requirement when no mesh (#2307)
jberchtold-nvidia Oct 31, 2025
e7227af
[Common] Deleted unused header (#2324)
Oleg-Goncharov Oct 31, 2025
c57ffc5
[JAX] L1_jax_distributed_test suit with individual executions (#2321)
phu0ngng Nov 3, 2025
3d76218
[PyTorch debug] Fixes to debug tests failures (#2268)
pggPL Nov 4, 2025
77a0063
[PyTorch Debug] Add max_blockwise_dynamic_range stats (#2137)
pggPL Nov 5, 2025
b6020e3
[JAX] Fix bug with pre scale bias (#2300)
pggPL Nov 5, 2025
dcaca2a
[JAX] Try to use pre-downloaded dataset artifacts first (#2345)
jberchtold-nvidia Nov 6, 2025
f3b97c2
Fix out of bounds access in the FP4 dequantize kernel (#2346)
ptrendx Nov 6, 2025
b14a3b6
Make FP8 weights compatible with older MCore version (#2342)
kunlunl Nov 6, 2025
4ff3eed
[JAX] Add test to check jaxpr that amax is reused for nvfp4 recipe (#…
jberchtold-nvidia Nov 7, 2025
f62cad9
Fix sharding of segment position to match id in ring attention. (#2349)
mgoldfarb-nvidia Nov 7, 2025
26aad6b
Disable cuDNN attention for known IMA and NaNs (#2344)
ksivaman Nov 7, 2025
5978f1d
[JAX] Default to fused attention in JAX DPA (#2363)
KshitijLakhani Nov 7, 2025
d20311b
Update cudnn frontend to v1.16.0 (#2362)
ksivaman Nov 7, 2025
3454f84
[common] Remove kvpacked and qkvpacked attention functions for every …
pggPL Nov 7, 2025
5ea8343
Move Triton to common (#2359)
tdophung Nov 10, 2025
7a58598
[JAX] Fused layers argument default values changed (#2347)
tdophung Nov 10, 2025
29537c9
[PyTorch] FSDP2 Support for TE (#2245)
vthumbe1503 Nov 11, 2025
f8693d2
Fix CI failure related to bug in MXFP8 copy implementation (#2369)
vthumbe1503 Nov 12, 2025
e4bfa62
[Feature] Enable rope application with offsets for training (#2188)
sudhakarsingh27 Nov 12, 2025
c544ced
[JAX] Relax tolerance for the test_multiprocessing_encoder.py with NV…
phu0ngng Nov 12, 2025
d8f1e68
fix gradient accumulation fusion for FSDP (#2371)
tomlifu Nov 13, 2025
d0d4063
[PyTorch] Fix amax computation using output_t data in normalization (…
negvet Nov 13, 2025
ef28c86
[JAX] NVFP4 scale swizzling via nvte kernel (#2350)
phu0ngng Nov 13, 2025
9440b76
[JAX] Shardy rule + QuantizeLayout Rework (#2364)
phu0ngng Nov 13, 2025
67d63d0
[JAX] Support for checkpointing quantizations (#2356)
jberchtold-nvidia Nov 13, 2025
0ded113
[JAX] XLA_FLAG to WAR the current NCCL issue with test_distributed_so…
phu0ngng Nov 13, 2025
262c184
[PyTorch] Add reset cudagraph interface (#2367)
buptzyb Nov 14, 2025
b88f727
[JAX] Make all jax attention calls use non-packed common calls (#2358)
pggPL Nov 14, 2025
a075475
[JAX] Improve support and testing for direct recipe usage without aut…
jberchtold-nvidia Nov 14, 2025
c525760
[PyTorch] Activation offloading refactor (#1762)
pggPL Nov 14, 2025
389a6ba
[JAX] Use TE quant if TE fused act is disabled (#2374)
jberchtold-nvidia Nov 14, 2025
07f3c6a
[ROCm] merge NV upstream v2.10 dev without resolving conflicts
wangye805 Jan 25, 2026
b8a4024
[ROCm] resolve the conflicts in common dir
wangye805 Feb 2, 2026
0519b4b
[ROCm] resolve the conflicts on jax side
wangye805 Feb 10, 2026
8f4b04d
[ROCm] resolve the conflicts on pytorch side
wangye805 Feb 10, 2026
e60ff21
[ROCm] resolve the conflicts in setup
wangye805 Feb 10, 2026
8bbb162
[ROCm] resolve the cpp gtest
wangye805 Feb 11, 2026
f573b40
[ROCm] resolve pytorch and jax tests
alextmagro Feb 11, 2026
eaaae94
pytest, example, wheels conflict resolution
alextmagro Feb 19, 2026
8f94cf6
jax and pytorch bugfix
alextmagro Feb 24, 2026
bac7993
copyrights and fp8_autocast->autocast fix
alextmagro Feb 24, 2026
8ae38e8
Enable test_distributed_dense.py
alextmagro Feb 24, 2026
05a977a
address IFU comments
alextmagro Mar 3, 2026
0385852
_FormatHelperFP8 and missing file add
alextmagro Mar 3, 2026
46d382d
add use_async_d2h_group_size as a test parameter
alextmagro Mar 3, 2026
4 changes: 3 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
@@ -83,7 +83,9 @@ jobs:
options: --user root
steps:
- name: 'Dependencies'
run: pip install torch pybind11[global] einops onnxscript
run: |
pip install pybind11[global] einops onnxscript
pip install torch --index-url https://download.pytorch.org/whl/cu130
- name: 'Checkout'
uses: actions/checkout@v3
with:
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
@@ -38,3 +38,9 @@ repos:
entry: clang-format -i
args: ["-style=file"]
files: ^transformer_engine.*\.(c|cc|cxx|cpp|cu|cuh|h|hpp)$

- repo: https://github.com/netromdk/vermin
rev: c75aca72f4e85c6e47252139e8695f1c8b5f9ae3
hooks:
- id: vermin
args: ['-t=3.10', '--violations']
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 159 files
14 changes: 11 additions & 3 deletions README.rst
@@ -362,6 +362,14 @@ Transformer Engine

Latest News
===========

* [09/2025] `Pretraining Large Language Models with NVFP4 <https://www.arxiv.org/pdf/2509.25149>`_
* [09/2025] `Native FP8 Mixed Precision Training for Ling 2.0, Open Sourced! <https://huggingface.co/blog/im0qianqian/ling-mini-2-fp8-mixed-precision-training-solution>`_
* [09/2025] `Faster Training Throughput in FP8 Precision with NVIDIA NeMo <https://developer.nvidia.com/blog/faster-training-throughput-in-fp8-precision-with-nvidia-nemo/>`_
* [08/2025] `How we built DeepL's next-generation LLMs with FP8 for training and inference <https://www.deepl.com/en/blog/tech/next-generation-llm-fp8-training>`_
* [08/2025] `NVFP4 Trains with Precision of 16-bit and Speed and Efficiency of 4-bit <https://developer.nvidia.com/blog/nvfp4-trains-with-precision-of-16-bit-and-speed-and-efficiency-of-4-bit/>`_
* [06/2025] `Floating Point 8: An Introduction to Efficient, Lower-Precision AI Training <https://developer.nvidia.com/blog/floating-point-8-an-introduction-to-efficient-lower-precision-ai-training/>`_
* [05/2025] `Advanced Optimization Strategies for LLM Training on NVIDIA Grace Hopper <https://developer.nvidia.com/blog/advanced-optimization-strategies-for-llm-training-on-nvidia-grace-hopper/>`_
* [03/2025] `Stable and Scalable FP8 Deep Learning Training on Blackwell | GTC 2025 <https://www.nvidia.com/en-us/on-demand/session/gtc25-s72778/>`_
* [03/2025] `Measure and Improve AI Workload Performance with NVIDIA DGX Cloud Benchmarking <https://developer.nvidia.com/blog/measure-and-improve-ai-workload-performance-with-nvidia-dgx-cloud-benchmarking/>`_

@@ -436,7 +444,7 @@ PyTorch
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
with te.autocast(enabled=True, recipe=fp8_recipe):
out = model(inp)

loss = out.sum()
@@ -471,7 +479,7 @@ Flax
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
with te.autocast(enabled=True, recipe=fp8_recipe):
model = te_flax.DenseGeneral(features=HIDDEN)

def loss_fn(params, other_vars, inp):
@@ -547,7 +555,7 @@ pip Installation
**Prerequisites for pip installation:**

* A compatible C++ compiler
* CUDA Toolkit with cuDNN and NVCC (NVIDIA CUDA Compiler) installed
* CUDA Toolkit with cuDNN and NVCC (NVIDIA CUDA Compiler) if installing from source.

To install the latest stable version with pip:

152 changes: 152 additions & 0 deletions benchmarks/benchmark_rht_cast.py
@@ -0,0 +1,152 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import argparse
import torch
import pandas as pd
import torch.utils.benchmark as benchmark

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
import transformer_engine.pytorch.cpp_extensions as ext

from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer

scale_padding_to = 1
permute_scale = False

TORCH_TO_TE_FLOAT_MAP = {
torch.bfloat16: tex.DType.kBFloat16,
}


def run_kernel(shape, stochastic_rounding: bool, input_dtype=torch.bfloat16):
# Generate random input data
M, K = shape
x = torch.randn([M, K], dtype=input_dtype, device="cuda")

assert shape[0] % 16 == 0, "Shape must be divisible by 16"
assert shape[1] % 16 == 0, "Shape must be divisible by 16"

# Quantize
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=tex.DType.kFloat4E2M1,
rowwise=True,
columnwise=True,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=True,
with_post_rht_amax=True,
with_random_sign_mask=True,
stochastic_rounding=stochastic_rounding,
)
x_nvfp4_sut = nvfp4_quantizer.make_empty(
(M, K), dtype=x.dtype, device=x.device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)

with torch.no_grad():
stmt = "kernel_func(input, output)"
globals_dict = {
"kernel_func": nvfp4_quantizer.update_quantized,
"input": x,
"output": x_nvfp4_sut,
}

timing = benchmark.Timer(
stmt=stmt,
globals=globals_dict,
num_threads=1,
).blocked_autorange(min_run_time=5)
print(timing)
timing_us = timing.median * 1e6

input_nbytes = shape[0] * shape[1] * 2 # bf16
output_nbytes = shape[0] * shape[1] // 2 # //2 for fp4
sf_nbytes = shape[0] * shape[1] // 16 # //16 for 1 byte per 16 elems

total_nbytes = (
0
+ input_nbytes
* 3 # Reading input for Amax(x)&Amax(RHT(x.T)), Reading input for Cast(x), Reading input for Cast(RHT(x.T))
+ 2 * 4 # Output 2 * float for scale & amax
+ 2 * 4 # Input 2 * float
+ output_nbytes * 2 # Output from Cast(x) and Cast(RHT(x.T))
+ sf_nbytes * 2 # Scale factor
)

throughput_GBps = total_nbytes / (1024 * 1024 * 1024) / (timing_us / 1e6)

print(
f"Stochastic rounding: {stochastic_rounding}, Total: {total_nbytes} bytes, Throughput:"
f" {throughput_GBps} GB/s"
)
return timing_us, throughput_GBps
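
The byte accounting in `run_kernel` above can be sanity-checked standalone. This sketch (plain Python; the function name is ours, not part of the benchmark) reproduces the traffic model for one of the benchmarked shapes:

```python
def nvfp4_traffic_bytes(M, K):
    """Model memory traffic for the fused NVFP4 cast (bf16 in, fp4 out).

    Mirrors the accounting in run_kernel: the bf16 input is read three
    times (the amax pass plus the two casts), two fp4 outputs are written
    (rowwise and transposed/RHT), and one scale-factor byte is stored per
    16-element block, plus a few scalar scale/amax values.
    """
    input_nbytes = M * K * 2    # bf16: 2 bytes per element
    output_nbytes = M * K // 2  # fp4: 2 elements per byte
    sf_nbytes = M * K // 16     # 1 scale byte per 16 elements
    return (
        3 * input_nbytes        # amax read + two cast reads
        + 2 * 4 + 2 * 4         # scalar scale & amax, in and out
        + 2 * output_nbytes     # Cast(x) and Cast(RHT(x.T)) outputs
        + 2 * sf_nbytes         # scale factors for both outputs
    )

# Example: the (8192, 5120) shape from the benchmark's shape list.
print(nvfp4_traffic_bytes(8192, 5120))  # → 298844176
```

Dividing this byte count by the measured kernel time gives the achieved-bandwidth figure the script prints (note it uses 1024³, i.e. GiB/s, though it is labeled GB/s).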


# Nsight Compute Profiling Command:
# ncu -f -o block_scaled_1d_cast_transpose_kernel --set=full --kernel-name "block_scaled_1d_cast_transpose_kernel" -s 5 -c 5 python benchmark_rht_cast.py --profile

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
args = parser.parse_args()

if args.profile:
print("Profiling is enabled.")
else:
print("Profiling is disabled.")

shapes = [
(8192, 5120),
(8192, 10240),
(8192, 2560),
(8192, 11328),
(8192, 512),
(8192, 3584),
(5120, 8192),
(10240, 8192),
(2560, 8192),
(11328, 8192),
(512, 8192),
(3584, 8192),
(4096, 16384),
(14336, 16384),
]

if args.profile:
shapes = [
(16384, 6144),
]

data = []
for stochastic_rounding in [True]: # , False]:
for shape in shapes:
print(
f"Running benchmark_func with shape {shape} and stochastic_rounding"
f" {stochastic_rounding}"
)
timing_us, throughput_GBps = run_kernel(shape, stochastic_rounding)
data.append(
[
"benchmark_func",
shape,
stochastic_rounding,
timing_us,
throughput_GBps,
]
)

df = pd.DataFrame(
data=data,
columns=[
"kernel",
"shape",
"stochastic_rounding",
"timing_us",
"throughput(GB/s)",
],
)
print(df)
df.to_csv("benchmark_cast_nvfp4.csv", index=False)
75 changes: 52 additions & 23 deletions benchmarks/linear/benchmark_grouped_linear.py
@@ -6,58 +6,69 @@
import torch
import torch.utils.benchmark as benchmark
import pandas as pd
import pathlib

from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.common.recipe import Float8BlockScaling, MXFP8BlockScaling
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
from transformer_engine.common.recipe import (
Float8BlockScaling,
MXFP8BlockScaling,
NVFP4BlockScaling,
)
from transformer_engine.pytorch.quantization import autocast, FP8GlobalStateManager
from contextlib import nullcontext

"""
# Profile BF16 recipe with Nsight Systems
nsys profile \
--output=./benchmarks/linear/b200_mkn_4096_4096_4096_numgemm_8_bf16 \
--output=./benchmarks/linear/b200_numgemm_8_bf16 \
--force-overwrite true \
--trace=cuda,nvtx,cudnn,cublas \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe bf16

# Profile FP8 sub-channel recipe with Nsight Systems
nsys profile \
--output=./benchmarks/linear/h100hbm_mkn_4096_4096_4096_numgemm_8_fp8_sub_channel \
--output=./benchmarks/linear/h100hbm_numgemm_8_fp8_sub_channel \
--force-overwrite true \
--trace=cuda,nvtx,cudnn,cublas \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe fp8_sub_channel

# Profile MXFP8 recipe with Nsight Systems
nsys profile \
--output=./benchmarks/linear/b200_mkn_4096_4096_4096_numgemm_8_mxfp8 \
--output=./benchmarks/linear/b200_numgemm_8_mxfp8 \
--force-overwrite true \
--trace=cuda,nvtx,cudnn,cublas \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe mxfp8

# Profile NVFP4 recipe with Nsight Systems
nsys profile \
--output=./benchmarks/linear/b200_numgemm_8_nvfp4 \
--force-overwrite true \
--trace=cuda,nvtx,cudnn,cublas \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4

"""

RECIPES = {
"bf16": None,
"fp8_sub_channel": Float8BlockScaling(),
"mxfp8": MXFP8BlockScaling(),
"nvfp4": NVFP4BlockScaling(),
}

mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available()


def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None):
assert mode in ["fwd_only", "fwd_bwd"]
fp8_context = (
fp8_autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext()
quantization_context = (
autocast(enabled=True, recipe=recipe) if recipe is not None else nullcontext()
)
# print(f"fp8_context: {fp8_context} and is it nullcontext? {isinstance(fp8_context, nullcontext)}")

if mode == "fwd_only":
with torch.no_grad(), fp8_context:
with torch.no_grad(), quantization_context:
for i in range(run_num_steps):
y_q = layer.forward(
x,
@@ -70,7 +81,7 @@ def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=
layer.zero_grad()
x.grad = None

with fp8_context:
with quantization_context:
for i in range(run_num_steps):
label = f"step_{i}"
torch.cuda.nvtx.range_push(label)
@@ -145,7 +156,7 @@ def benchmark_linear(
"recipe": recipe,
},
num_threads=1,
).blocked_autorange(min_run_time=5)
).blocked_autorange(min_run_time=10)
print(f"{recipe_name}: {timing} \n")
timing_ms = timing.median * 1000 / num_microbatches

@@ -228,30 +239,44 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):

use_bias = False
# Set the MKN values to benchmark
# Deepseek V3 EP64, SEQ_LEN=8192, topK8
# 256 expert => 4 local experts
# Avg M per expert: AvgM = SEQ_LEN * topK / localExperts = 16384
# M = AvgM * localExperts = 65536
# K = 7168
# N = 2048

# Deepseek V3 EP32, SEQ_LEN=8192, topK8
# 256 expert => 8 local experts
# Avg M per expert: AvgM = SEQ_LEN * topK / localExperts = 8192
# M = AvgM * localExperts = 65536
# K = 7168
# N = 2048

# 4 or 8 local experts per rank
num_gemms_list = [4, 8]
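
The DeepSeek-V3 sizing comments above reduce to simple arithmetic; this quick check (plain Python, with all values taken from the comments) confirms that both EP configurations land on the same total M:

```python
SEQ_LEN, TOP_K, NUM_EXPERTS = 8192, 8, 256

for ep in (64, 32):                    # expert-parallel sizes from the comments
    local_experts = NUM_EXPERTS // ep  # 4 experts/rank at EP64, 8 at EP32
    avg_m = SEQ_LEN * TOP_K // local_experts
    total_m = avg_m * local_experts    # total token rows in the grouped GEMM
    print(f"EP{ep}: local_experts={local_experts}, AvgM={avg_m}, M={total_m}")
```

Either way M = 65536 with K = 7168 and N = 2048, which is why the benchmark below uses the single (65536, 7168, 2048) MKN tuple for both `num_gemms` settings.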

# MKN for group linear
mkns = []
for m in [8192]:
# for m in [4096, 8192, 16384]:
# for n in [1024, 2048, 4096, 8192, 16384]:
for n in [8192]:
for k in [4096]:
for m in [65536]:
for k in [7168]:
for n in [2048]:
mkns.append((m, k, n))

# default recipes to run if not specified
recipe_list = ["bf16"]

if args.recipe == "all":
recipe_list = ["bf16", "fp8_sub_channel", "mxfp8"]
recipe_list = ["bf16", "fp8_sub_channel", "mxfp8", "nvfp4"]
else:
recipe_list = [args.recipe]

num_gemms_list = [8]

if args.profile:
mkns = [(4096 * 8, 4096, 4096)]
mkns = [(8192 * 8, 7168, 2048)]
# in profile mode, only run one recipe specified in args.recipe
assert args.recipe != "all", (
"In profile mode, only one recipe can be specified, please specify the recipe as"
" fp8_sub_channel, mxfp8, or bf16"
" fp8_sub_channel, mxfp8, nvfp4, or bf16"
)
recipe_list = [args.recipe]
num_gemms_list = [8]
@@ -268,13 +293,17 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
"bf16",
"fp8_sub_channel",
"mxfp8",
], "Recipe must be one of bf16, fp8_sub_channel, or mxfp8"
"nvfp4",
], "Recipe must be one of bf16, fp8_sub_channel, mxfp8, or nvfp4"
if recipe_name == "mxfp8" and not mxfp8_available:
print(f"MXFP8 is not available, skipping {recipe_name}")
continue
if recipe_name == "fp8_sub_channel" and not fp8_block_scaling_available:
print(f"FP8 block scaling is not available, skipping {recipe_name}")
continue
if recipe_name == "nvfp4" and not nvfp4_available:
print(f"NVFP4 is not available, skipping {recipe_name}")
continue

df = run_benchmark_linear(
mkns,
2 changes: 1 addition & 1 deletion build_tools/VERSION.txt
@@ -1 +1 @@
2.8.0.dev0
2.10.0.dev0