Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions tests/pytorch/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
scaled_init_method_normal,
)
from transformer_engine.pytorch.utils import get_cudnn_version
from transformer_engine.pytorch.constants import FP8BwdTensorIdx, FP8FwdTensorIdx
import transformer_engine_torch as tex
from transformer_engine.pytorch.quantized_tensor import (
Quantizer,
Expand Down Expand Up @@ -2581,12 +2582,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False

META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3
# FP8 meta-tensor index aliases for the fused-attention path.
# Forward-pass tensors (QKV, O, S) use FP8FwdTensorIdx values — these are
# used below to index quantizers["scaling_fwd"]; gradient tensors
# (dQKV, dO, dP) use FP8BwdTensorIdx values for quantizers["scaling_bwd"].
META_QKV = FP8FwdTensorIdx.GEMM1_OUTPUT
META_DQKV = FP8BwdTensorIdx.GRAD_OUTPUT1
META_O = FP8FwdTensorIdx.GEMM2_INPUT
META_DO = FP8BwdTensorIdx.GRAD_INPUT2
META_S = FP8FwdTensorIdx.GEMM3_OUTPUT
META_DP = FP8BwdTensorIdx.GRAD_INPUT3


class _custom_mha_fp8(torch.autograd.Function):
Expand Down Expand Up @@ -2614,14 +2615,14 @@ def forward(
d = in_features // h
b = cu_seqlens.numel() - 1

input_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
qkv_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
qkv_weight_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
o_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
dO_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
dQKV_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
s_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT2]
dP_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT3]
input_quantizer = quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT]
qkv_quantizer = quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM2_INPUT]
qkv_weight_quantizer = quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT]
o_quantizer = quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT]
dO_quantizer = quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1]
dQKV_quantizer = quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1]
s_quantizer = quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT2]
dP_quantizer = quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT3]

inp_fp8 = input_quantizer(inp)

Expand Down
21 changes: 11 additions & 10 deletions tests/pytorch/test_custom_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.common import recipe
from transformer_engine.pytorch.constants import FP8BwdTensorIdx, FP8FwdTensorIdx
from transformer_engine.pytorch import (
autocast,
Linear,
Expand Down Expand Up @@ -169,11 +170,11 @@ def test_custom_recipe_matches_current_scaling():
with autocast(enabled=True, recipe=ref_recipe):
out_ref = model_ref(inp_ref)
# Assert dtypes for reference quantizers: HYBRID = E4M3 (fwd), E5M2 (bwd)
ref_fwd_in = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
ref_fwd_w = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
ref_fwd_out = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
ref_bwd_go = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
ref_bwd_gi = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
ref_fwd_in = model_ref.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT]
ref_fwd_w = model_ref.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT]
ref_fwd_out = model_ref.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT]
ref_bwd_go = model_ref.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1]
ref_bwd_gi = model_ref.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1]
assert ref_fwd_in.dtype == tex.DType.kFloat8E4M3
assert ref_fwd_w.dtype == tex.DType.kFloat8E4M3
assert ref_fwd_out.dtype == tex.DType.kFloat8E4M3
Expand All @@ -200,11 +201,11 @@ def quantizer_factory(role):
with autocast(enabled=True, recipe=custom_recipe):
out_custom = model_custom(inp_custom)
# Assert dtypes for custom quantizers match reference mapping
cus_fwd_in = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
cus_fwd_w = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
cus_fwd_out = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
cus_bwd_go = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
cus_bwd_gi = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
cus_fwd_in = model_custom.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT]
cus_fwd_w = model_custom.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT]
cus_fwd_out = model_custom.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT]
cus_bwd_go = model_custom.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1]
cus_bwd_gi = model_custom.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1]
assert cus_fwd_in.dtype == tex.DType.kFloat8E4M3
assert cus_fwd_w.dtype == tex.DType.kFloat8E4M3
assert cus_fwd_out.dtype == tex.DType.kFloat8E4M3
Expand Down
12 changes: 8 additions & 4 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
// Set conditions for MXFP8 and NVFP4 gemm execution.
const auto nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode);
const auto mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode);
int is_nvte_non_tn_fp8_gemm_supported = 0; // needed only for per tensor scaling
if (is_tensor_scaling(A.scaling_mode) || is_tensor_scaling(B.scaling_mode)) {
is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
}

// Configure A matrix
if (is_tensor_scaling(A.scaling_mode)) {
Expand All @@ -129,7 +133,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.Atype = A.data.dtype;
ret.A_scale_inv = A.scale_inv.dptr;
ret.lda = is_A_transposed ? k : m;
if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) {
// Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
ret.A = A.columnwise_data.dptr;
Expand All @@ -140,7 +144,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
} else {
NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
}
} else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) {
} else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) {
// Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
// data with the mirrored transpose-flag if we don't have row-wise data.
NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype),
Expand Down Expand Up @@ -220,7 +224,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.Btype = B.data.dtype;
ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) {
// Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
ret.B = B.columnwise_data.dptr;
Expand All @@ -231,7 +235,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
} else {
NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
}
} else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) {
} else if (is_nvte_non_tn_fp8_gemm_supported && !B.has_data()) {
// Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
// data with the mirrored transpose-flag if we don't have row-wise data.
NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype),
Expand Down
21 changes: 20 additions & 1 deletion transformer_engine/common/util/cuda_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

#include <cuda.h>

#include <mutex>
#include <string>
#include <unordered_map>

#include "../common.h"
#include "../util/string.h"
Expand All @@ -29,13 +31,30 @@ void *get_symbol(const char *symbol, int cuda_version = 12010);
* without GPUs. Indirect function calls into a lazily-initialized
* library ensures we are accessing the correct version.
*
* Symbol pointers are cached to avoid repeated lookups.
*
* \param[in] symbol Function name
* \param[in] args Function arguments
*/
/*! \brief Call a CUDA driver-API entry point by name.
 *
 * The driver library is loaded lazily via get_symbol(); resolved symbol
 * pointers are cached so repeated calls avoid the dlsym-style lookup.
 * Note that the cache (and its mutex) are per template instantiation,
 * i.e. one cache per distinct argument-type pack, which keeps symbols
 * resolved with different signatures from colliding.
 *
 * \param[in] symbol Function name
 * \param[in] args   Function arguments
 */
template <typename... ArgTs>
inline CUresult call(const char *symbol, ArgTs... args) {
  using FuncT = CUresult(ArgTs...);

  static std::unordered_map<std::string, void *> symbol_cache;
  static std::mutex cache_mutex;
  FuncT *func;

  {
    std::lock_guard<std::mutex> lock(cache_mutex);
    auto it = symbol_cache.find(symbol);
    if (it == symbol_cache.end()) {
      // Single insertion instead of find + operator[] (one fewer hash
      // lookup on a miss). get_symbol() is evaluated before the node is
      // created, so if it throws, no stale entry is left in the cache.
      it = symbol_cache.emplace(symbol, get_symbol(symbol)).first;
    }
    func = reinterpret_cast<FuncT *>(it->second);
  }
  // Invoke outside the critical section so slow driver calls do not
  // serialize unrelated lookups.
  return (*func)(args...);
}

Expand Down
9 changes: 9 additions & 0 deletions transformer_engine/debug/pytorch/debug_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,3 +697,12 @@ def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None
raise RuntimeError(
"Cannot recreate columnwise tensor from rowwise tensor is debug mode."
)

@property
def device(self):
    """Device of the underlying data tensor.

    Implemented as an explicit property so the lookup stays cheap
    (avoids expensive PyObject attribute resolution). Prefers the
    rowwise GEMM tensor; falls back to the columnwise one.
    """
    rowwise = self.rowwise_gemm_tensor
    if rowwise is not None:
        return rowwise.device
    columnwise = self.columnwise_gemm_tensor
    if columnwise is not None:
        return columnwise.device
    raise RuntimeError("DebugQuantizedTensor has no data!")
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,10 @@ def mask_func(x, y):
bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and layer_number is not None
)

def fast_setattr(self, name: str, value: Any) -> None:
    """Fast attribute set for non-parameter fields.

    Writes straight into the instance dictionary, which never triggers
    custom ``__setattr__`` machinery, so it is cheaper than normal
    attribute assignment.
    """
    vars(self)[name] = value

def forward(
self,
_alibi_cache: Dict[str, Any],
Expand Down
6 changes: 5 additions & 1 deletion transformer_engine/pytorch/attention/multi_head_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""Multi-head Attention."""
import os
import collections
from typing import Callable, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union
import torch

from transformer_engine.pytorch.quantization import FP8GlobalStateManager
Expand Down Expand Up @@ -478,6 +478,10 @@ def __init__(
**common_gemm_kwargs,
)

def fast_setattr(self, name: str, value: Any) -> None:
    """Fast attribute set for non-parameter fields.

    Bypasses normal attribute-assignment dispatch by writing into the
    instance ``__dict__`` directly; dictionary stores never invoke a
    custom ``__setattr__``.
    """
    vars(self)[name] = value

def _create_qk_norm_modules(
self,
qk_norm_type: Optional[str],
Expand Down
20 changes: 20 additions & 0 deletions transformer_engine/pytorch/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# See LICENSE for license information.

"""Enums for e2e transformer"""
from types import SimpleNamespace
import torch
import torch.distributed
import transformer_engine_torch as tex
Expand Down Expand Up @@ -40,6 +41,25 @@
tex.DType.kBFloat16: torch.bfloat16,
}

# Integer snapshots of the tex FP8 tensor-index enums. Converting the
# enum members to plain ints once at import time avoids repeated
# PyObject/enum-attribute lookups on hot paths.
_FP8_FWD_FIELDS = (
    "GEMM1_INPUT",
    "GEMM1_WEIGHT",
    "GEMM1_OUTPUT",
    "GEMM2_INPUT",
    "GEMM2_WEIGHT",
    "GEMM2_OUTPUT",
    "GEMM3_OUTPUT",
)
_FP8_BWD_FIELDS = (
    "GRAD_INPUT1",
    "GRAD_INPUT2",
    "GRAD_INPUT3",
    "GRAD_OUTPUT1",
    "GRAD_OUTPUT2",
    "GRAD_OUTPUT3",
)
FP8FwdTensorIdx = SimpleNamespace(
    **{field: int(getattr(tex.FP8FwdTensors, field)) for field in _FP8_FWD_FIELDS}
)
FP8BwdTensorIdx = SimpleNamespace(
    **{field: int(getattr(tex.FP8BwdTensors, field)) for field in _FP8_BWD_FIELDS}
)

AttnMaskTypes = (
"no_mask",
"padding",
Expand Down
13 changes: 7 additions & 6 deletions transformer_engine/pytorch/cpp_extensions/fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
NVTE_Fused_Attn_Backend,
)
from ..quantized_tensor import Quantizer
from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx


__all__ = [
Expand Down Expand Up @@ -103,12 +104,12 @@
BACKEND_F16m512_FP8_THREADS_PER_CTA = 128
BACKEND_F16arb_ELTS_PER_THREADS = 16

META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3
# FP8 meta-tensor index aliases used by the fused-attention extensions.
# Forward tensors (QKV, O, S) are plain-int indices from FP8FwdTensorIdx;
# gradient tensors (dQKV, dO, dP) come from FP8BwdTensorIdx.
META_QKV = FP8FwdTensorIdx.GEMM1_OUTPUT
META_DQKV = FP8BwdTensorIdx.GRAD_OUTPUT1
META_O = FP8FwdTensorIdx.GEMM2_INPUT
META_DO = FP8BwdTensorIdx.GRAD_INPUT2
META_S = FP8FwdTensorIdx.GEMM3_OUTPUT
META_DP = FP8BwdTensorIdx.GRAD_INPUT3


def fused_attn_fwd(
Expand Down
26 changes: 2 additions & 24 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,28 +67,6 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float:
return 0.0


def get_tensor_device(tensor: torch.Tensor) -> int:
    """Resolve the device index of ``tensor`` as an integer.

    Attribute probing is used instead of isinstance checks on
    QuantizedTensor or Storage because those checks incur more CPU
    overhead; the probe order is chosen to hit the common case first.
    Falls back to the current CUDA device when no data attribute is set.
    """
    if hasattr(tensor, "device"):
        return tensor.device.index
    for attr in ("_rowwise_data", "_columnwise_data", "_data", "_transpose"):
        candidate = getattr(tensor, attr, None)
        if candidate is not None:
            return candidate.device.index
    return torch.cuda.current_device()


def general_gemm(
A: torch.Tensor,
B: torch.Tensor,
Expand Down Expand Up @@ -117,7 +95,7 @@ def general_gemm(

alpha = validate_gemm_scale(alpha, True)
beta = validate_gemm_scale(beta, accumulate)
workspace = get_cublas_workspace(get_tensor_device(A), ub is not None, False)
workspace = get_cublas_workspace(A.device.index, ub is not None, False)

if ub_type is not None:
assert ub is not None, (
Expand Down Expand Up @@ -235,7 +213,7 @@ def general_grouped_gemm(
out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype

sm_count = get_sm_count()
workspaces = get_cublas_workspace(get_tensor_device(A[0]), False, True)
workspaces = get_cublas_workspace(A[0].device.index, False, True)

if grad and use_bias:
grad_bias = [
Expand Down
17 changes: 8 additions & 9 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;
PyTypeObject *NVFP4TensorPythonClass = nullptr;
PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
PyTypeObject *NVFP4QuantizerClass = nullptr;
std::once_flag extension_init_flag;
PyTypeObject *GroupedTensorStoragePythonClass = nullptr;

void init_float8_extension() {
if (Float8TensorPythonClass) return;
auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor");
Float8QuantizerClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer"));
Expand All @@ -55,7 +55,6 @@ void init_float8_extension() {
}

void init_mxfp8_extension() {
if (MXFP8TensorPythonClass) return;
auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.mxfp8_tensor");
MXFP8QuantizerClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Quantizer"));
Expand All @@ -70,7 +69,6 @@ void init_mxfp8_extension() {
}

void init_float8blockwise_extension() {
if (Float8BlockwiseQTensorStoragePythonClass) return;
auto fp8_module =
py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor");
auto fp8_base_module = py::module_::import(
Expand All @@ -91,7 +89,6 @@ void init_float8blockwise_extension() {
}

void init_nvfp4_extensions() {
if (NVFP4TensorPythonClass) return;
auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor");
NVFP4QuantizerClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer"));
Expand All @@ -116,11 +113,13 @@ void init_grouped_tensor_extension() {
}

// One-time, thread-safe initialization of all Python binding caches.
// std::call_once guarantees the helpers run exactly once even when
// multiple threads enter concurrently; each helper imports a Python
// tensor module and caches its quantizer/tensor type objects.
void init_extension() {
  std::call_once(extension_init_flag, []() {
    init_float8_extension();
    init_mxfp8_extension();
    init_float8blockwise_extension();
    init_nvfp4_extensions();
    init_grouped_tensor_extension();
  });
}

} // namespace transformer_engine::pytorch
Expand Down
Loading
Loading