From 9dac78e76a8e6c33add4d0b1aec8b3dd2c7db8db Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Mon, 2 Mar 2026 16:58:43 -0800 Subject: [PATCH 1/2] CPU Overhead Optimizations (#2559) * add all the optimizations Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * requires_grad optimization Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test if commenting out requires_grad works Signed-off-by: Varun Thumbe * fix minor bug Signed-off-by: Varun Thumbe * fix ci Signed-off-by: Varun Thumbe * missed a bug Signed-off-by: Varun Thumbe * Update transformer_engine/pytorch/csrc/quantizer.cpp Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 * fix some bugs pointed to by copilot Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * linting error Signed-off-by: Varun Thumbe * fix the error Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the bug Signed-off-by: Varun Thumbe * get rid of the change Signed-off-by: Varun Thumbe * fix the transpose shape bug Signed-off-by: Varun Thumbe * minor linter fix Signed-off-by: Varun Thumbe * fix lint Signed-off-by: Varun Thumbe * fix linting error Signed-off-by: Varun Thumbe * address copilot review comment regarding error check when both data and transpose are None Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix linting errors Signed-off-by: Varun Thumbe * missed a merge conflict Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * final optimizations Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix ci error Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comment from greptile Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comment + stride optimization Signed-off-by: Varun Thumbe * address linter issue Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor lint Signed-off-by: Varun Thumbe * fix ci bug Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * another optimization to do at::native::empty_cuda directly instead of at::empty Signed-off-by: Varun Thumbe * cleanups Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * better solution for device Signed-off-by: Varun Thumbe * enum to int cache Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove unused function Signed-off-by: Varun Thumbe * Update transformer_engine/pytorch/tensor/float8_blockwise_tensor.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 * index instead of device bug Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix ci: Signed-off-by: Varun Thumbe * debug quantized tensor fix Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert cudnnt front end change Signed-off-by: Varun Thumbe --------- Signed-off-by: Varun Thumbe Signed-off-by: vthumbe1503 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- tests/pytorch/attention/test_attention.py | 29 +- tests/pytorch/test_custom_recipe.py | 21 +- .../common/gemm/cublaslt_gemm.cu | 12 +- transformer_engine/common/util/cuda_driver.h | 21 +- .../debug/pytorch/debug_quantization.py | 9 + transformer_engine/pytorch/constants.py | 20 ++ .../pytorch/cpp_extensions/fused_attn.py | 13 +- .../pytorch/cpp_extensions/gemm.py | 26 +- .../pytorch/csrc/extensions/pybind.cpp | 17 +- transformer_engine/pytorch/csrc/quantizer.cpp | 314 ++++++++++++++---- transformer_engine/pytorch/module/base.py | 5 +- .../pytorch/module/layernorm_linear.py | 42 +-- .../pytorch/module/layernorm_mlp.py | 62 ++-- transformer_engine/pytorch/module/linear.py | 42 +-- .../pytorch/quantized_tensor.py | 76 ++++- .../pytorch/tensor/float8_blockwise_tensor.py | 18 + .../pytorch/tensor/float8_tensor.py | 19 ++ .../pytorch/tensor/mxfp8_tensor.py | 28 ++ .../pytorch/tensor/nvfp4_tensor.py | 30 ++ .../float8_blockwise_tensor_storage.py | 9 + .../tensor/storage/float8_tensor_storage.py | 9 + .../tensor/storage/mxfp8_tensor_storage.py | 9 + .../tensor/storage/nvfp4_tensor_storage.py | 9 + 23 files changed, 622 insertions(+), 218 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 31c7041897..60ade522e3 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -44,6 +44,7 @@ scaled_init_method_normal, ) from transformer_engine.pytorch.utils import get_cudnn_version +from transformer_engine.pytorch.constants import FP8BwdTensorIdx, FP8FwdTensorIdx import transformer_engine_torch as tex from transformer_engine.pytorch.quantized_tensor import ( Quantizer, @@ -2581,12 +2582,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: _2X_ACC_DGRAD = False _2X_ACC_WGRAD = False -META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT -META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 -META_O = tex.FP8FwdTensors.GEMM2_INPUT -META_DO = tex.FP8BwdTensors.GRAD_INPUT2 -META_S = tex.FP8FwdTensors.GEMM3_OUTPUT -META_DP = tex.FP8BwdTensors.GRAD_INPUT3 +META_QKV = FP8FwdTensorIdx.GEMM1_OUTPUT +META_DQKV = FP8BwdTensorIdx.GRAD_OUTPUT1 +META_O = FP8FwdTensorIdx.GEMM2_INPUT +META_DO = FP8BwdTensorIdx.GRAD_INPUT2 +META_S = FP8FwdTensorIdx.GEMM3_OUTPUT +META_DP = FP8BwdTensorIdx.GRAD_INPUT3 class _custom_mha_fp8(torch.autograd.Function): @@ -2614,14 +2615,14 @@ def forward( d = in_features // h b = cu_seqlens.numel() - 1 - input_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] - qkv_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] - qkv_weight_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] - o_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] - dO_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] - dQKV_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] - s_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT2] - dP_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT3] + input_quantizer = quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT] + qkv_quantizer = quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM2_INPUT] + qkv_weight_quantizer = quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT] + o_quantizer = quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT] + dO_quantizer = quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1] + dQKV_quantizer = quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1] + s_quantizer = quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT2] + dP_quantizer = quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT3] inp_fp8 = input_quantizer(inp) diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index 4de49115b3..536d43adc0 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -8,6 +8,7 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.common import recipe +from transformer_engine.pytorch.constants import FP8BwdTensorIdx, FP8FwdTensorIdx from transformer_engine.pytorch import ( autocast, Linear, @@ -169,11 +170,11 @@ def test_custom_recipe_matches_current_scaling(): with autocast(enabled=True, recipe=ref_recipe): out_ref = model_ref(inp_ref) # Assert dtypes for reference quantizers: HYBRID = E4M3 (fwd), E5M2 (bwd) - ref_fwd_in = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] - ref_fwd_w = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] - ref_fwd_out = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] - ref_bwd_go = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] - ref_bwd_gi = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + ref_fwd_in = model_ref.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT] + ref_fwd_w = model_ref.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT] + ref_fwd_out = model_ref.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT] + ref_bwd_go = model_ref.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1] + ref_bwd_gi = model_ref.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1] assert ref_fwd_in.dtype == tex.DType.kFloat8E4M3 assert ref_fwd_w.dtype == tex.DType.kFloat8E4M3 assert ref_fwd_out.dtype == tex.DType.kFloat8E4M3 @@ -200,11 +201,11 @@ def quantizer_factory(role): with autocast(enabled=True, recipe=custom_recipe): out_custom = model_custom(inp_custom) # Assert dtypes for custom quantizers match reference mapping - cus_fwd_in = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] - cus_fwd_w = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] - cus_fwd_out = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] - cus_bwd_go = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] - cus_bwd_gi = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + cus_fwd_in = model_custom.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT] + cus_fwd_w = model_custom.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT] + cus_fwd_out = model_custom.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT] + cus_bwd_go = model_custom.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1] + cus_bwd_gi = model_custom.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1] assert cus_fwd_in.dtype == tex.DType.kFloat8E4M3 assert cus_fwd_w.dtype == tex.DType.kFloat8E4M3 assert cus_fwd_out.dtype == tex.DType.kFloat8E4M3 diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index c58c3cb47a..144aea1a07 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -120,6 +120,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // Set conditions for MXFP8 and NVFP4 gemm execution. const auto nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode); const auto mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode); + int is_nvte_non_tn_fp8_gemm_supported = 0; // needed only for per tensor scaling + if (is_tensor_scaling(A.scaling_mode) || is_tensor_scaling(B.scaling_mode)) { + is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + } // Configure A matrix if (is_tensor_scaling(A.scaling_mode)) { @@ -129,7 +133,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Atype = A.data.dtype; ret.A_scale_inv = A.scale_inv.dptr; ret.lda = is_A_transposed ? k : m; - if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) { + if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { ret.A = A.columnwise_data.dptr; @@ -140,7 +144,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } - } else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) { + } else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) { // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed // data with the mirrored transpose-flag if we don't have row-wise data. NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype), @@ -220,7 +224,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Btype = B.data.dtype; ret.B_scale_inv = B.scale_inv.dptr; ret.ldb = is_B_transposed ? n : k; - if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) { + if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { ret.B = B.columnwise_data.dptr; @@ -231,7 +235,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); } - } else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) { + } else if (is_nvte_non_tn_fp8_gemm_supported && !B.has_data()) { // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed // data with the mirrored transpose-flag if we don't have row-wise data. NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype), diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h index 2715d8e4e4..16242347f1 100644 --- a/transformer_engine/common/util/cuda_driver.h +++ b/transformer_engine/common/util/cuda_driver.h @@ -9,7 +9,9 @@ #include +#include #include +#include #include "../common.h" #include "../util/string.h" @@ -29,13 +31,30 @@ void *get_symbol(const char *symbol, int cuda_version = 12010); * without GPUs. Indirect function calls into a lazily-initialized * library ensures we are accessing the correct version. * + * Symbol pointers are cached to avoid repeated lookups. + * * \param[in] symbol Function name * \param[in] args Function arguments */ template inline CUresult call(const char *symbol, ArgTs... args) { using FuncT = CUresult(ArgTs...); - FuncT *func = reinterpret_cast(get_symbol(symbol)); + + static std::unordered_map symbol_cache; + static std::mutex cache_mutex; + FuncT *func; + + { + std::lock_guard lock(cache_mutex); + auto it = symbol_cache.find(symbol); + if (it == symbol_cache.end()) { + void *ptr = get_symbol(symbol); + symbol_cache[symbol] = ptr; + func = reinterpret_cast(ptr); + } else { + func = reinterpret_cast(it->second); + } + } return (*func)(args...); } diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 5624970547..57a5967079 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -697,3 +697,12 @@ def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None raise RuntimeError( "Cannot recreate columnwise tensor from rowwise tensor is debug mode." ) + + @property + def device(self): + """Return the device of the tensor. Define this to avoid expensive PyObject lookups.""" + if self.rowwise_gemm_tensor is not None: + return self.rowwise_gemm_tensor.device + if self.columnwise_gemm_tensor is not None: + return self.columnwise_gemm_tensor.device + raise RuntimeError("DebugQuantizedTensor has no data!") diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index 3cce4600d9..2aff4fd8e8 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Enums for e2e transformer""" +from types import SimpleNamespace import torch import torch.distributed import transformer_engine_torch as tex @@ -40,6 +41,25 @@ tex.DType.kBFloat16: torch.bfloat16, } +# Cache enum -> int conversions to avoid repeated PyObject lookups. +FP8FwdTensorIdx = SimpleNamespace( + GEMM1_INPUT=int(tex.FP8FwdTensors.GEMM1_INPUT), + GEMM1_WEIGHT=int(tex.FP8FwdTensors.GEMM1_WEIGHT), + GEMM1_OUTPUT=int(tex.FP8FwdTensors.GEMM1_OUTPUT), + GEMM2_INPUT=int(tex.FP8FwdTensors.GEMM2_INPUT), + GEMM2_WEIGHT=int(tex.FP8FwdTensors.GEMM2_WEIGHT), + GEMM2_OUTPUT=int(tex.FP8FwdTensors.GEMM2_OUTPUT), + GEMM3_OUTPUT=int(tex.FP8FwdTensors.GEMM3_OUTPUT), +) +FP8BwdTensorIdx = SimpleNamespace( + GRAD_INPUT1=int(tex.FP8BwdTensors.GRAD_INPUT1), + GRAD_INPUT2=int(tex.FP8BwdTensors.GRAD_INPUT2), + GRAD_INPUT3=int(tex.FP8BwdTensors.GRAD_INPUT3), + GRAD_OUTPUT1=int(tex.FP8BwdTensors.GRAD_OUTPUT1), + GRAD_OUTPUT2=int(tex.FP8BwdTensors.GRAD_OUTPUT2), + GRAD_OUTPUT3=int(tex.FP8BwdTensors.GRAD_OUTPUT3), +) + AttnMaskTypes = ( "no_mask", "padding", diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 101e5b2525..e9f64bb693 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -16,6 +16,7 @@ NVTE_Fused_Attn_Backend, ) from ..quantized_tensor import Quantizer +from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx __all__ = [ @@ -103,12 +104,12 @@ BACKEND_F16m512_FP8_THREADS_PER_CTA = 128 BACKEND_F16arb_ELTS_PER_THREADS = 16 -META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT -META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 -META_O = tex.FP8FwdTensors.GEMM2_INPUT -META_DO = tex.FP8BwdTensors.GRAD_INPUT2 -META_S = tex.FP8FwdTensors.GEMM3_OUTPUT -META_DP = tex.FP8BwdTensors.GRAD_INPUT3 +META_QKV = FP8FwdTensorIdx.GEMM1_OUTPUT +META_DQKV = FP8BwdTensorIdx.GRAD_OUTPUT1 +META_O = FP8FwdTensorIdx.GEMM2_INPUT +META_DO = FP8BwdTensorIdx.GRAD_INPUT2 +META_S = FP8FwdTensorIdx.GEMM3_OUTPUT +META_DP = FP8BwdTensorIdx.GRAD_INPUT3 def fused_attn_fwd( diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 406e7075f7..a37f1c2d4d 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -67,28 +67,6 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float: return 0.0 -def get_tensor_device(tensor: torch.Tensor) -> int: - """ - Returns tensor device as an integer. - - This method is used because checking instances of - QuantizedTensor or Storage incurs more CPU overhead. - The order of attributes checked is important to also - minimize overhead. - """ - if hasattr(tensor, "device"): - return tensor.device.index - if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None: - return tensor._rowwise_data.device.index - if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None: - return tensor._columnwise_data.device.index - if hasattr(tensor, "_data") and tensor._data is not None: - return tensor._data.device.index - if hasattr(tensor, "_transpose") and tensor._transpose is not None: - return tensor._transpose.device.index - return torch.cuda.current_device() - - def general_gemm( A: torch.Tensor, B: torch.Tensor, @@ -117,7 +95,7 @@ def general_gemm( alpha = validate_gemm_scale(alpha, True) beta = validate_gemm_scale(beta, accumulate) - workspace = get_cublas_workspace(get_tensor_device(A), ub is not None, False) + workspace = get_cublas_workspace(A.device.index, ub is not None, False) if ub_type is not None: assert ub is not None, ( @@ -235,7 +213,7 @@ def general_grouped_gemm( out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype sm_count = get_sm_count() - workspaces = get_cublas_workspace(get_tensor_device(A[0]), False, True) + workspaces = get_cublas_workspace(A[0].device.index, False, True) if grad and use_bias: grad_bias = [ diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index e9683ca41e..b9fc65363d 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -35,10 +35,10 @@ PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; PyTypeObject *NVFP4TensorPythonClass = nullptr; PyTypeObject *NVFP4TensorStoragePythonClass = nullptr; PyTypeObject *NVFP4QuantizerClass = nullptr; +std::once_flag extension_init_flag; PyTypeObject *GroupedTensorStoragePythonClass = nullptr; void init_float8_extension() { - if (Float8TensorPythonClass) return; auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor"); Float8QuantizerClass = reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer")); @@ -55,7 +55,6 @@ void init_float8_extension() { } void init_mxfp8_extension() { - if (MXFP8TensorPythonClass) return; auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.mxfp8_tensor"); MXFP8QuantizerClass = reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Quantizer")); @@ -70,7 +69,6 @@ void init_mxfp8_extension() { } void init_float8blockwise_extension() { - if (Float8BlockwiseQTensorStoragePythonClass) return; auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor"); auto fp8_base_module = py::module_::import( @@ -91,7 +89,6 @@ void init_float8blockwise_extension() { } void init_nvfp4_extensions() { - if (NVFP4TensorPythonClass) return; auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor"); NVFP4QuantizerClass = reinterpret_cast( PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer")); @@ -116,11 +113,13 @@ void init_grouped_tensor_extension() { } void init_extension() { - init_float8_extension(); - init_mxfp8_extension(); - init_float8blockwise_extension(); - init_nvfp4_extensions(); - init_grouped_tensor_extension(); + std::call_once(extension_init_flag, []() { + init_float8_extension(); + init_mxfp8_extension(); + init_float8blockwise_extension(); + init_nvfp4_extensions(); + init_grouped_tensor_extension(); + }); } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index e715d8f5ba..0da5f69197 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -31,6 +31,23 @@ std::vector make_transpose_shape(const std::vector& shape) { return ret; } +/*! @brief Calculate stride from shape for contiguous tensors */ +template +std::vector stride_from_shape(const std::vector& shape) { + std::vector stride; + if (shape.empty()) { + return stride; + } + std::vector rstride; + rstride.reserve(shape.size()); + rstride.push_back(static_cast(1)); + for (size_t i = shape.size(); i > 1; --i) { + rstride.push_back(rstride.back() * shape[i - 1]); + } + stride.assign(rstride.rbegin(), rstride.rend()); + return stride; +} + /*! @brief Convert shape for FP4 data by dividing the last dimension by 2 */ template std::vector convert_shape_for_fp4(const std::vector& shape) { @@ -206,9 +223,9 @@ std::pair Float8Quantizer::create_tensor( const std::vector& shape, DType dtype, std::optional data, std::optional transpose, std::optional scale_inv) const { using namespace pybind11::literals; - + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Initialize data tensor - const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; if (with_data && !data) { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -219,7 +236,7 @@ std::pair Float8Quantizer::create_tensor( py::object data_py = with_data ? py::cast(*data) : py::none(); // Initialize transpose tensor - const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; if (with_transpose && !transpose) { const auto transpose_shape = make_transpose_shape(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -228,26 +245,58 @@ std::pair Float8Quantizer::create_tensor( transpose.reset(); } py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); - // Initialize scale-inverse tensor if (!scale_inv) { scale_inv = at::reciprocal(scale); } - + py::object scale_inv_py = py::cast(*scale_inv); + at::Device device = + with_data ? data->device() + : (with_transpose ? transpose->device() + : at::Device(torch::kCUDA, c10::cuda::current_device())); // Construct Python FP8 tensor py::object out_py; if (internal) { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorStoragePythonClass)); - out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + py::tuple args(0); + kwargs["data"] = data_py; + kwargs["fp8_scale_inv"] = scale_inv_py; + kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["data_transpose"] = transpose_py; + kwargs["quantizer"] = this->quantizer; + + PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); const std::vector shape_int64(shape.begin(), shape.end()); - out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + const auto stride_int64 = stride_from_shape(shape_int64); + + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + py::tuple args(0); + kwargs["shape"] = py::cast(shape_int64); + kwargs["stride"] = py::cast(stride_int64); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["data"] = data_py; + kwargs["fp8_scale_inv"] = scale_inv_py; + kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["data_transpose"] = transpose_py; + kwargs["quantizer"] = this->quantizer; + kwargs["device"] = py::cast(device); + PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorPythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + + NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ FP8 tensor @@ -337,10 +386,10 @@ std::pair Float8Quantizer::create_grouped_tens std::pair Float8Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); - + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Expected buffers - const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); - const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; + const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid usages for Float8Quantizer."); // Extract buffers from Python tensor @@ -480,7 +529,8 @@ std::pair Float8CurrentScalingQuantizer::create_tenso // Initialize data tensor at::Tensor data_tensor; - const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; if (with_data) { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -489,13 +539,12 @@ std::pair Float8CurrentScalingQuantizer::create_tenso // Initialize transpose tensor at::Tensor transpose_tensor; - const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; if (with_transpose) { const auto transpose_shape = make_transpose_shape(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); transpose_tensor = at::empty(transpose_shape, opts); } - // Initialize scale-inverse tensor at::Tensor scale_inv_tensor; { @@ -503,23 +552,55 @@ std::pair Float8CurrentScalingQuantizer::create_tenso const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); scale_inv_tensor = at::empty(scale_inv_shape, opts); } - + at::Device device = + with_data ? data_tensor.device() + : (with_transpose ? transpose_tensor.device() + : at::Device(torch::kCUDA, c10::cuda::current_device())); // Construct Python FP8 tensor py::object out_py; + py::object scale_inv_py = py::cast(scale_inv_tensor); py::object data_py = with_data ? py::cast(data_tensor) : py::none(); py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none(); if (internal) { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorStoragePythonClass)); - out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + kwargs["data"] = data_py; + kwargs["fp8_scale_inv"] = scale_inv_py; + kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["data_transpose"] = transpose_py; + kwargs["quantizer"] = this->quantizer; + + py::tuple args(0); + PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); const std::vector shape_int64(shape.begin(), shape.end()); - out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + const auto stride_int64 = stride_from_shape(shape_int64); + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + kwargs["shape"] = py::cast(shape_int64); + kwargs["stride"] = py::cast(stride_int64); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["data"] = data_py; + kwargs["fp8_scale_inv"] = scale_inv_py; + kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["data_transpose"] = transpose_py; + kwargs["quantizer"] = this->quantizer; + kwargs["device"] = py::cast(device); + py::tuple args(0); + PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorPythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + + NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ FP8 tensor @@ -627,10 +708,10 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8CurrentScalingQuantizer must output to Float8Tensor."); - + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Expected buffers - const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); - const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; + const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages."); // Extract buffers from Python tensor @@ -837,21 +918,49 @@ std::pair Float8BlockQuantizer::create_tensor( py::object ret; if (internal) { - py::handle Float8BlockwiseQTensorClass( - reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass)); - ret = Float8BlockwiseQTensorClass( - "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, - "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer, - "is_2D_scaled"_a = (block_scaling_dim == 2)); + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + kwargs["rowwise_data"] = py::cast(data_rowwise); + kwargs["columnwise_data"] = py::cast(data_colwise); + kwargs["rowwise_scale_inv"] = py::cast(scale_inv_rowwise); + kwargs["columnwise_scale_inv"] = py::cast(scale_inv_colwise); + kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["quantizer"] = this->quantizer; + kwargs["is_2D_scaled"] = py::cast(block_scaling_dim == 2); + + py::tuple args(0); + PyObject* result = + PyObject_Call(reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + + NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensorStorage instance"); + ret = py::reinterpret_steal(result); } else { - py::handle Float8BlockwiseQTensorClass( - reinterpret_cast(Float8BlockwiseQTensorPythonClass)); - ret = Float8BlockwiseQTensorClass( - "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, - "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, - "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2)); + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + const auto stride_int64 = stride_from_shape(torch_shape); + kwargs["shape"] = py::cast(torch_shape); + kwargs["stride"] = py::cast(stride_int64); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["rowwise_data"] = py::cast(data_rowwise); + kwargs["columnwise_data"] = py::cast(data_colwise); + kwargs["rowwise_scale_inv"] = py::cast(scale_inv_rowwise); + kwargs["columnwise_scale_inv"] = py::cast(scale_inv_colwise); + kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["quantizer"] = this->quantizer; + kwargs["is_2D_scaled"] = py::cast(block_scaling_dim == 2); + + py::tuple args(0); + PyObject* result = PyObject_Call(reinterpret_cast(Float8BlockwiseQTensorPythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensor instance"); + ret = py::reinterpret_steal(result); } return {std::move(tensor), std::move(ret)}; @@ -1198,18 +1307,49 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve // Construct Python MXFP8 tensor py::object out_py; if (internal) { - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorStoragePythonClass)); - out_py = MXFP8TensorClass(rowwise_data_py, rowwise_scale_inv_py, columnwise_data_py, - columnwise_scale_inv_py, this->dtype, this->quantizer, - with_gemm_swizzled_scales); + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + py::tuple args(0); + kwargs["rowwise_data"] = rowwise_data_py; + kwargs["columnwise_data"] = columnwise_data_py; + kwargs["rowwise_scale_inv"] = rowwise_scale_inv_py; + kwargs["columnwise_scale_inv"] = columnwise_scale_inv_py; + kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["quantizer"] = this->quantizer; + kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + + PyObject* result = PyObject_Call(reinterpret_cast(MXFP8TensorStoragePythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + + NVTE_CHECK(result != nullptr, "Failed to create MXFP8TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); - out_py = MXFP8TensorClass( - "shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, - "rowwise_scale_inv"_a = rowwise_scale_inv_py, - "columnwise_scale_inv"_a = columnwise_scale_inv_py, "fp8_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer, "with_gemm_swizzled_scales"_a = with_gemm_swizzled_scales); + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + const auto stride_int64 = stride_from_shape(shape_int64); + kwargs["shape"] = py::cast(shape_int64); + kwargs["stride"] = py::cast(stride_int64); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["rowwise_data"] = rowwise_data_py; + kwargs["columnwise_data"] = columnwise_data_py; + kwargs["rowwise_scale_inv"] = rowwise_scale_inv_py; + kwargs["columnwise_scale_inv"] = columnwise_scale_inv_py; + kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["quantizer"] = this->quantizer; + kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + + py::tuple args(0); + PyObject* result = PyObject_Call(reinterpret_cast(MXFP8TensorPythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + + NVTE_CHECK(result != nullptr, "Failed to create MXFP8Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ MXFP8 tensor @@ -1561,19 +1701,53 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve // Construct Python NVFP4 tensor py::object out_py; if (internal) { - py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorStoragePythonClass)); - out_py = NVFP4TensorClass(rowwise_data_py, rowwise_scale_inv_py, columnwise_data_py, - columnwise_scale_inv_py, amax_rowwise_py, amax_columnwise_py, - this->dtype, this->quantizer, with_gemm_swizzled_scales); + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + kwargs["rowwise_data"] = rowwise_data_py; + kwargs["columnwise_data"] = columnwise_data_py; + kwargs["rowwise_scale_inv"] = rowwise_scale_inv_py; + kwargs["columnwise_scale_inv"] = columnwise_scale_inv_py; + kwargs["amax_rowwise"] = amax_rowwise_py; + kwargs["amax_columnwise"] = amax_columnwise_py; + kwargs["fp4_dtype"] = py::cast(this->dtype); + kwargs["quantizer"] = this->quantizer; + kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + + py::tuple args(0); + + PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorStoragePythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + + NVTE_CHECK(result != nullptr, "Failed to create NVFP4TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorPythonClass)); - out_py = NVFP4TensorClass( - "shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, - "rowwise_scale_inv"_a = rowwise_scale_inv_py, - "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, - "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer, "with_gemm_swizzled_scales"_a = with_gemm_swizzled_scales); + // Use direct C API call bypassing pybind11 overhead + py::dict kwargs; + const auto stride_int64 = stride_from_shape(shape_int64); + kwargs["shape"] = py::cast(shape_int64); + kwargs["stride"] = py::cast(stride_int64); + kwargs["dtype"] = py::cast(GetATenDType(dtype)); + kwargs["rowwise_data"] = rowwise_data_py; + kwargs["columnwise_data"] = columnwise_data_py; + kwargs["rowwise_scale_inv"] = rowwise_scale_inv_py; + kwargs["columnwise_scale_inv"] = columnwise_scale_inv_py; + kwargs["amax_rowwise"] = amax_rowwise_py; + kwargs["amax_columnwise"] = amax_columnwise_py; + kwargs["fp4_dtype"] = py::cast(this->dtype); + kwargs["quantizer"] = this->quantizer; + kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + py::tuple args(0); + PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), + args.ptr(), kwargs.ptr()); + if (result == nullptr) { + PyErr_Print(); + } + + NVTE_CHECK(result != nullptr, "Failed to create NVFP4Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ tensor diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 4858383c26..9c21141a39 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -929,12 +929,11 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: if torch.is_autocast_enabled(): self.fast_setattr("activation_dtype", torch_get_autocast_gpu_dtype()) return - + dtype = inp.dtype # All checks after this have already been performed once, thus skip - if self.activation_dtype == inp.dtype: + if self.activation_dtype == dtype: return - dtype = inp.dtype if not self.allow_different_data_and_param_types: for name, param in self.named_parameters(): if param is not None: diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 27632db15b..ce0581024a 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -52,7 +52,7 @@ _fsdp_scatter_tensors, _fsdp_gather_tensors, ) -from ..constants import GemmParallelModes, dist_group_type +from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx, GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ._common import apply_normalization, noop_cat, WeightGradStore @@ -1357,7 +1357,7 @@ def __init__( torch.nn.Parameter(weight_tensor[split_start:split_end]), init_fn=init_method, get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + fp8_meta_index=FP8FwdTensorIdx.GEMM1_WEIGHT, ) # Construct bias parameters if needed @@ -1615,20 +1615,20 @@ def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): grad_weight_quantizer = None grad_output_quantizer = None output_quantizer = None - input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + input_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT] input_quantizer.internal = True if not (self.parallel_mode == "column" and self.sequence_parallel): input_quantizer.optimize_for_gemm = True (weight_quantizer,) = self._get_weight_quantizers() if fp8_output: - output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + output_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT] if is_grad_enabled: - grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + grad_output_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1] grad_output_quantizer.internal = True if not (self.parallel_mode == "row" and self.sequence_parallel): grad_output_quantizer.optimize_for_gemm = True if fp8_grad: - grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + grad_input_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1] return ( input_quantizer, @@ -1725,43 +1725,43 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe if fwd: # set configs about amax epsilon and power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon # also set weight quantizer with same amax_epsilon & power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_WEIGHT + FP8FwdTensorIdx.GEMM1_WEIGHT ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_WEIGHT + FP8FwdTensorIdx.GEMM1_WEIGHT ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon # parallel related if self.sequence_parallel and self.parallel_mode == "column": # set input_quantizer with amax reduction TP group self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].with_amax_reduction = True self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].amax_reduction_group = self.tp_group else: # set grad_output_quantizer with amax epsilon and power_2_scale (no amax reduction here) self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon # parallel related if self.sequence_parallel and self.parallel_mode == "row": # customize grad_output_quantizer with amax reduction TP group self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].with_amax_reduction = True self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: @@ -1771,19 +1771,19 @@ def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: if self.sequence_parallel and self.parallel_mode == "column": # set input_quantizer with amax reduction TP group self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].with_amax_reduction = True self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].amax_reduction_group = self.tp_group else: if self.sequence_parallel and self.parallel_mode == "row": # customize grad_output_quantizer with amax reduction TP group self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].with_amax_reduction = True self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: @@ -1807,6 +1807,6 @@ def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" if not self.fp8 and not self.fp8_calibration: return [None] - weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + weight_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT] weight_quantizer.internal = True return [weight_quantizer] diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index b8823e46ca..16e620fd94 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -59,7 +59,7 @@ _get_cuda_rng_state, _set_cuda_rng_state, ) -from ..constants import dist_group_type +from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ..tensor.float8_tensor import ( @@ -1909,7 +1909,7 @@ def __init__( fc1_weight, init_fn=init_method, get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + fp8_meta_index=FP8FwdTensorIdx.GEMM1_WEIGHT, ) if self.use_bias: @@ -1929,7 +1929,7 @@ def __init__( fc2_weight, init_fn=output_layer_init_method, get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT, + fp8_meta_index=FP8FwdTensorIdx.GEMM2_WEIGHT, ) if self.use_bias: @@ -2201,11 +2201,11 @@ def _get_quantizers(self, fp8_output, is_grad_enabled): ) = [None] * 10 fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers() if self.fp8 or self.fp8_calibration: - fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + fc1_input_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT] fc1_input_quantizer.internal = True if not self.sequence_parallel: fc1_input_quantizer.optimize_for_gemm = True - fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] + fc2_input_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM2_INPUT] fc2_input_quantizer.set_usage( rowwise=True, columnwise=isinstance( @@ -2216,18 +2216,16 @@ def _get_quantizers(self, fp8_output, is_grad_enabled): fc2_input_quantizer.internal = True fc2_input_quantizer.optimize_for_gemm = True if fp8_output: - fc2_output_quantizer = self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM2_OUTPUT - ] + fc2_output_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM2_OUTPUT] if is_grad_enabled: fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT2 + FP8BwdTensorIdx.GRAD_OUTPUT2 ] fc2_grad_output_quantizer.internal = True if not self.sequence_parallel: fc2_grad_output_quantizer.optimize_for_gemm = True fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ] fc1_grad_output_quantizer.internal = True fc1_grad_output_quantizer.optimize_for_gemm = True @@ -2389,63 +2387,63 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe if fwd: # fc1_input_quantizer: set configs about amax epsilon and power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon # fc2_input_quantizer self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM2_INPUT + FP8FwdTensorIdx.GEMM2_INPUT ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM2_INPUT + FP8FwdTensorIdx.GEMM2_INPUT ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon # fc1_weight_quantizer: also set numerical configs about weight self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_WEIGHT + FP8FwdTensorIdx.GEMM1_WEIGHT ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_WEIGHT + FP8FwdTensorIdx.GEMM1_WEIGHT ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon # fc2_weight_quantizer self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM2_WEIGHT + FP8FwdTensorIdx.GEMM2_WEIGHT ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM2_WEIGHT + FP8FwdTensorIdx.GEMM2_WEIGHT ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon # parallel related if self.sequence_parallel and self.set_parallel_mode: # fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].with_amax_reduction = True self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].amax_reduction_group = self.tp_group else: # fc2_grad_output_quantizer: set configs about amax epsilon and power_2_scale for fc2_grad_output_quantizer self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT2 + FP8BwdTensorIdx.GRAD_OUTPUT2 ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT2 + FP8BwdTensorIdx.GRAD_OUTPUT2 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon # fc1_grad_output_quantizer: also set numerical configs for fc1_grad_output_quantizer self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon if self.sequence_parallel and self.set_parallel_mode: # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT2 + FP8BwdTensorIdx.GRAD_OUTPUT2 ].with_amax_reduction = True self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT2 + FP8BwdTensorIdx.GRAD_OUTPUT2 ].amax_reduction_group = self.tp_group def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: @@ -2455,19 +2453,19 @@ def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: if self.sequence_parallel and self.set_parallel_mode: # fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].with_amax_reduction = True self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].amax_reduction_group = self.tp_group else: if self.sequence_parallel and self.set_parallel_mode: # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT2 + FP8BwdTensorIdx.GRAD_OUTPUT2 ].with_amax_reduction = True self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT2 + FP8BwdTensorIdx.GRAD_OUTPUT2 ].amax_reduction_group = self.tp_group def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: @@ -2478,9 +2476,9 @@ def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" if not self.fp8 and not self.fp8_calibration: return [None, None] - fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + fc1_weight_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT] fc1_weight_quantizer.internal = True - fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] + fc2_weight_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM2_WEIGHT] fc2_weight_quantizer.internal = True return [fc1_weight_quantizer, fc2_weight_quantizer] diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index a55429d33d..31dac4d329 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -54,7 +54,7 @@ from ..cpp_extensions import ( general_gemm, ) -from ..constants import GemmParallelModes, dist_group_type +from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx, GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ..quantized_tensor import ( @@ -1272,7 +1272,7 @@ def __init__( torch.nn.Parameter(weight_tensor[split_start:split_end]), init_fn=init_method, get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + fp8_meta_index=FP8FwdTensorIdx.GEMM1_WEIGHT, ) # Construct bias parameters if needed @@ -1483,20 +1483,20 @@ def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): grad_weight_quantizer = None grad_output_quantizer = None output_quantizer = None - input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + input_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT] input_quantizer.internal = True if not (self.parallel_mode == "column" and self.sequence_parallel): input_quantizer.optimize_for_gemm = True (weight_quantizer,) = self._get_weight_quantizers() if fp8_output: - output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + output_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT] if is_grad_enabled: - grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + grad_output_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1] grad_output_quantizer.internal = True if not (self.parallel_mode == "row" and self.sequence_parallel): grad_output_quantizer.optimize_for_gemm = True if fp8_grad: - grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + grad_input_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1] return ( input_quantizer, weight_quantizer, @@ -1601,43 +1601,43 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe if fwd: # set configs about amax epsilon and power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon # also set weight quantizer with same amax_epsilon & power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_WEIGHT + FP8FwdTensorIdx.GEMM1_WEIGHT ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_WEIGHT + FP8FwdTensorIdx.GEMM1_WEIGHT ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon # paralle related if self.sequence_parallel and self.parallel_mode == "column": # customize input_quantizer with amax reduction TP group self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].with_amax_reduction = True self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].amax_reduction_group = self.tp_group else: # set grad_output_quantizer with amax epsilon and power_2_scale self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon # parallel related if self.sequence_parallel and self.parallel_mode == "row": # customize grad_output_quantizer with amax reduction TP group self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].with_amax_reduction = True self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: @@ -1647,25 +1647,25 @@ def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: if self.sequence_parallel and self.parallel_mode == "column": # customize input_quantizer with amax reduction TP group self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].with_amax_reduction = True self.quantizers["scaling_fwd"][ - tex.FP8FwdTensors.GEMM1_INPUT + FP8FwdTensorIdx.GEMM1_INPUT ].amax_reduction_group = self.tp_group else: if self.sequence_parallel and self.parallel_mode == "row": # customize grad_output_quantizer with amax reduction TP group self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].with_amax_reduction = True self.quantizers["scaling_bwd"][ - tex.FP8BwdTensors.GRAD_OUTPUT1 + FP8BwdTensorIdx.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" if not self.fp8 and not self.fp8_calibration: return [None] - weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + weight_quantizer = self.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT] weight_quantizer.internal = True return [weight_quantizer] diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index d78677bc83..cb697bc197 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -369,9 +369,13 @@ def __new__( *, requires_grad: bool = False, device: Optional[torch.device] = None, + stride: Optional[Iterable[int]] = None, ): - # We are assuming only contiguous tensors - stride = _stride_from_shape(shape) + # For stride, We are assuming only contiguous tensors + # Calculate stride from shape if not provided. When creating this object from + # C++ code, we provide the stride computed from shape in C++ to avoid the + # PyobjectVectorCall overhead of calling _stride_from_shape from C++ to Python. + stride = _stride_from_shape(shape) if stride is None else stride instance = torch.Tensor._make_wrapper_subclass( cls, shape, @@ -382,9 +386,75 @@ def __new__( requires_grad=requires_grad, device=torch.cuda.current_device() if device is None else device, ) - + instance._requires_grad = requires_grad + instance._dtype = dtype return instance + @property + def dtype(self) -> torch.dtype: + """ + Return the high precision data type of the tensor + Attribute access of custom tensors goes through an + expensive Pyobject lookup. Since dtype for a tensor is never + change after creation, we cache it in a member variable and return + """ + # Lazy initialization for tensors created via alternate paths + if not hasattr(self, "_dtype"): + # pylint: disable=unnecessary-dunder-call + self._dtype = torch._C.TensorBase.dtype.__get__(self, type(self)) + return self._dtype + + @dtype.setter + def dtype(self, value: torch.dtype) -> None: + """Set dtype property""" + self._dtype = value + + @property + def requires_grad(self) -> bool: + """ + Return whether or not the tensor requires gradient. + Attribute access of custom tensors goes through an + expensive Pyobject lookup. Since requires_grad is set during + initialization and may be updated, we cache it in a member variable. + """ + # Fallback to parent if not cached yet + if not hasattr(self, "_requires_grad"): + # pylint: disable=unnecessary-dunder-call + self._requires_grad = torch._C.TensorBase.requires_grad.__get__(self, type(self)) + return self._requires_grad + + @requires_grad.setter + def requires_grad(self, value: bool) -> None: + """Set requires_grad property so that autograd engine is aware of the change""" + # Update the cached value and call parent class method to ensure autograd engine is aware + self.requires_grad_(value) + + def requires_grad_(self, requires_grad: bool = True) -> QuantizedTensor: + """Cache requires_grad property and call parent class method""" + # pylint: disable=missing-function-docstring + # Update the cached value + self._requires_grad = requires_grad + # Call parent class method to ensure autograd engine is aware + super().requires_grad_(requires_grad) + return self + + def _get_data(self) -> torch.Tensor: + """Get tensor data property""" + return super().data + + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + Updates the underlying tensor data and syncs the dtype cache. + """ + # Update the parent class's data descriptor + # pylint: disable=unnecessary-dunder-call + super(QuantizedTensor, type(self)).data.__set__(self, tensor) + # Update the dtype cache + self._dtype = tensor.dtype + + # Create the data property with getter and setter + data = property(_get_data, _set_data) + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Convert quantized data to standard PyTorch tensor""" raise NotImplementedError( diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index ecafb6ddfc..a3d49ea4e9 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -567,6 +567,24 @@ def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): # Cast to FP8 when setting Float8BlockwiseQTensor.data data = property(_get_data, _set_data) + @property + def shape(self): + """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._rowwise_data is not None: + return self._rowwise_data.shape + if self._columnwise_data is not None: + return self._columnwise_data.shape + raise RuntimeError("Float8BlockwiseQTensor has no data!") + + @property + def is_cuda(self): + """Return whether the tensor is on a CUDA device.""" + if self._rowwise_data is not None: + return self._rowwise_data.is_cuda + if self._columnwise_data is not None: + return self._columnwise_data.is_cuda + raise RuntimeError("Float8BlockwiseQTensor has no data!") + class _ViewFunc(torch.autograd.Function): """View function diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 55bca49af3..f66e88740f 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -926,6 +926,25 @@ def fsdp_post_all_gather( ) return out, all_gather_outputs + @property + def shape(self): + """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._data is not None: + return self._data.shape + if self._transpose is not None: + transpose_shape = self._transpose.shape + return torch.Size(tuple(transpose_shape[1:]) + (transpose_shape[0],)) + raise RuntimeError("Both data and transpose are None") + + @property + def is_cuda(self): + """Return whether the tensor is on a CUDA device.""" + if self._data is not None: + return self._data.is_cuda + if self._transpose is not None: + return self._transpose.is_cuda + raise RuntimeError("Both data and transpose are None") + @classmethod def _make_in_reduce_ex( cls, diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 41d6c87f2b..96b6a67ea8 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -842,6 +842,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: ) # pylint: disable=unnecessary-dunder-call super(MXFP8Tensor, type(self)).data.__set__(self, dummy_tensor) + self._rowwise_data = tensor._rowwise_data self._columnwise_data = tensor._columnwise_data self._quantizer = tensor._quantizer.copy() @@ -861,6 +862,33 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Cast to FP8 when setting MXFP8Tensor.data data = property(_get_data, _set_data) + @property + def device(self): + """Return the device of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._rowwise_data is not None: + return self._rowwise_data.device + if self._columnwise_data is not None: + return self._columnwise_data.device + raise RuntimeError("MXFP8Tensor has no data!") + + @property + def shape(self): + """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._rowwise_data is not None: + return self._rowwise_data.shape + if self._columnwise_data is not None: + return self._columnwise_data.shape + raise RuntimeError("MXFP8Tensor has no data!") + + @property + def is_cuda(self): + """Return whether the tensor is on a CUDA device.""" + if self._rowwise_data is not None: + return self._rowwise_data.is_cuda + if self._columnwise_data is not None: + return self._columnwise_data.is_cuda + raise RuntimeError("MXFP8Tensor has no data!") + class _ViewFunc(torch.autograd.Function): """View function diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 66f986a900..a8148b5752 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -700,6 +700,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: ) # pylint: disable=unnecessary-dunder-call super(NVFP4Tensor, type(self)).data.__set__(self, dummy_tensor) + self._rowwise_data = tensor._rowwise_data self._columnwise_data = tensor._columnwise_data self._quantizer = tensor._quantizer @@ -719,6 +720,35 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Cast to FP8 when setting NVFP4Tensor.data data = property(_get_data, _set_data) + @property + def device(self): + """Return the device of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._rowwise_data is not None: + return self._rowwise_data.device + if self._columnwise_data is not None: + return self._columnwise_data.device + raise RuntimeError("NVFP4Tensor has no data!") + + @property + def shape(self): + """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._rowwise_data is not None: + byte_shape = self._rowwise_data.shape + return torch.Size(byte_shape[:-1] + (byte_shape[-1] * 2,)) + if self._columnwise_data is not None: + byte_shape = self._columnwise_data.shape + return torch.Size(byte_shape[1:-1] + (byte_shape[-1] * 2, byte_shape[0])) + raise RuntimeError("NVFP4Tensor has no data!") + + @property + def is_cuda(self): + """Return whether the tensor is on a CUDA device.""" + if self._rowwise_data is not None: + return self._rowwise_data.is_cuda + if self._columnwise_data is not None: + return self._columnwise_data.is_cuda + raise RuntimeError("NVFP4Tensor has no data!") + class _ViewFunc(torch.autograd.Function): """View function diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index 4cd6d19cd8..2a86717017 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -290,6 +290,15 @@ def size(self, *args, **kwargs): reordered.append(dims[0]) return torch.Size(reordered) + @property + def device(self): + """Return the device of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._rowwise_data is not None: + return self._rowwise_data.device + if self._columnwise_data is not None: + return self._columnwise_data.device + raise RuntimeError("Float8BlockwiseQTensorStorage has no data!") + def _create_columnwise(self): """ Update columnwise data and columnwise scale inv. Can only be used when using 2D scaling. diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index 9adb86c453..a815b366b2 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -170,6 +170,15 @@ def size(self, *args, **kwargs): size = self._transpose.size(*args, **kwargs) return torch.Size([size[-1], math.prod(size[:-1])]) + @property + def device(self): + """Return the device of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._data is not None: + return self._data.device + if self._transpose is not None: + return self._transpose.device + raise RuntimeError("Float8TensorStorage has no data!") + def view(self, shape: torch.Size): # pylint: disable=missing-function-docstring out_data = self._data.view(shape) diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 5c8510488f..12757aa58c 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -185,6 +185,15 @@ def size(self, *args, **kwargs): return self._rowwise_data.size(*args, **kwargs) return self._columnwise_data.size(*args, **kwargs) + @property + def device(self): + """Return the device of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._rowwise_data is not None: + return self._rowwise_data.device + if self._columnwise_data is not None: + return self._columnwise_data.device + raise RuntimeError("MXFP8TensorStorage has no data!") + def view(self, shape: torch.Size): # pylint: disable=missing-function-docstring diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 8be23d0c19..36bf208bcd 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -228,6 +228,15 @@ def size(self, dim: Optional[int] = None) -> Union[torch.Size, int]: return torch.Size(shape) return shape[dim] + @property + def device(self): + """Return the device of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._rowwise_data is not None: + return self._rowwise_data.device + if self._columnwise_data is not None: + return self._columnwise_data.device + raise RuntimeError("NVFP4TensorStorage has no data!") + def view(self, shape: torch.Size): # pylint: disable=missing-function-docstring From c68ec3101d0dc16fe6eb40294a5fed3a9370b6a8 Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Mon, 2 Mar 2026 18:16:53 -0800 Subject: [PATCH 2/2] Add fast_set_attr to modules not inheriting from base.py (#2724) fix fast_set_attr in other nn modules for fsdp Signed-off-by: Varun Thumbe --- .../pytorch/attention/dot_product_attention/backends.py | 4 ++++ .../pytorch/attention/multi_head_attention.py | 6 +++++- transformer_engine/pytorch/module/layernorm.py | 6 +++++- transformer_engine/pytorch/module/rmsnorm.py | 6 +++++- transformer_engine/pytorch/transformer.py | 6 +++++- 5 files changed, 24 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index aa6c063951..a6a8b0b26a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -293,6 +293,10 @@ def mask_func(x, y): bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None ) + def fast_setattr(self, name: str, value: Any) -> None: + """Fast attribute set for non-parameter fields.""" + self.__dict__[name] = value + def forward( self, _alibi_cache: Dict[str, Any], diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 01c4955d78..5c581849e6 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -5,7 +5,7 @@ """Multi-head Attention.""" import os import collections -from typing import Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import torch from transformer_engine.pytorch.quantization import FP8GlobalStateManager @@ -478,6 +478,10 @@ def __init__( **common_gemm_kwargs, ) + def fast_setattr(self, name: str, value: Any) -> None: + """Fast attribute set for non-parameter fields.""" + self.__dict__[name] = value + def _create_qk_norm_modules( self, qk_norm_type: Optional[str], diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index d4f0a78ba2..54fad8d1bc 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -4,7 +4,7 @@ """LayerNorm API""" import warnings -from typing import Iterable, Optional, Union +from typing import Any, Iterable, Optional, Union import torch @@ -102,6 +102,10 @@ def __init__( **kwargs, ) + def fast_setattr(self, name: str, value: Any) -> None: + """Fast attribute set for non-parameter fields.""" + self.__dict__[name] = value + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index ace4be31de..f8d5aade5c 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -4,7 +4,7 @@ """RMSNorm API""" import warnings -from typing import Iterable, Optional, Union +from typing import Any, Iterable, Optional, Union import torch @@ -106,6 +106,10 @@ def __init__( **kwargs, ) + def fast_setattr(self, name: str, value: Any) -> None: + """Fast attribute set for non-parameter fields.""" + self.__dict__[name] = value + def reset_rms_norm_parameters(self) -> None: """Deprecated""" warnings.warn( diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index cf7ce5e1a4..868cbbdac8 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -6,7 +6,7 @@ import os import warnings from contextlib import nullcontext -from typing import Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import torch @@ -545,6 +545,10 @@ def __init__( device=device, ) + def fast_setattr(self, name: str, value: Any) -> None: + """Fast attribute set for non-parameter fields.""" + self.__dict__[name] = value + def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: """ Set the tensor parallel group for the given