diff --git a/README.rst b/README.rst index 66d1b0b3e..0f29d8eef 100644 --- a/README.rst +++ b/README.rst @@ -354,6 +354,19 @@ legacy single-stage atomic kernel by setting: NVTE_USE_ATOMIC_AMAX=1 +Grouped GEMM using CK_Tile +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Transformer Engine provides a CK_Tile-based implementation of grouped GEMM +as an alternative to the hipBLASLt-based default grouped GEMM implementation. +The CK_Tile backend improves performance in most supported configurations. + +You can enable the CK_Tile-based backend using the same environment variables as in the +upstream CUTLASS implementation: + + NVTE_USE_CUTLASS_GROUPED_GEMM=1 # Enable CK_Tile-based grouped GEMM + NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK=1 # Print a warning if falling back to the hipBLASLt backend (e.g., due to an unsupported config) + Transformer Engine ****************** diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index bc29d29e3..4be6e69e7 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -148,7 +148,7 @@ def rocm_attn_backend() -> tuple[bool, bool, bool]: use_cutlass_grouped_gemm = [False] # Only enable cutlass grouped gemm on Hopper -if torch.cuda.get_device_capability() == (9, 0): +if torch.cuda.get_device_capability() == (9, 0) or IS_HIP_EXTENSION: use_cutlass_grouped_gemm.append(True) @@ -1386,7 +1386,7 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_ if IS_HIP_EXTENSION: if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias: - pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.") + pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.") te_linear_ref = Linear( config.hidden_size, @@ -1678,7 +1678,7 @@ def test_layernorm_linear_accuracy_delay_wgrad_compute( ): if IS_HIP_EXTENSION: if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias: - pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.") + 
pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.") config = model_configs[model] ln_linear_ref = LayerNormLinear( @@ -1892,7 +1892,7 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute( if IS_HIP_EXTENSION: if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias: - pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.") + pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.") ln_mlp = LayerNormMLP( hidden_size=config.hidden_size, @@ -2042,7 +2042,7 @@ def test_grouped_linear_accuracy( if IS_HIP_EXTENSION: if dtype not in (torch.float32,) and fuse_wgrad_accumulation and not fp8: - pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.") + pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.") if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") @@ -2121,6 +2121,8 @@ def test_grouped_linear_accuracy( atol, rtol = 0, 0 if use_cutlass: atol, rtol = 1e-3, 1e-3 + if IS_HIP_EXTENSION: + atol, rtol = 1e-3, 8e-3 if use_triton: atol, rtol = get_tolerances(dtype) if dtype == torch.float32: @@ -2131,7 +2133,7 @@ def test_grouped_linear_accuracy( @pytest.mark.skipif( - torch.cuda.get_device_capability() != (9, 0), + torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION, reason="Only enable CUTLASS grouped gemm on Hopper", ) @pytest.mark.parametrize("dtype", param_types, ids=str) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 50dcf90a0..d5aab2cda 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. 
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -202,6 +202,7 @@ else() fused_attn_rocm/fused_attn_ck.cpp fused_attn_rocm/utils.cpp gemm/rocm_gemm.cu + gemm/ck_grouped_gemm.cpp amd_detail/system.cpp) # process source code files @@ -250,6 +251,9 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) else() message(FATAL_ERROR "cutlass gemm/cutlass_grouped_gemm.cu kernel required sm 90a") endif() +else() + set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel) + target_include_directories(transformer_engine PRIVATE ${CK_ROOT}/include) endif() #USE_CUDA # Configure dependencies diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp new file mode 100644 index 000000000..61c08f0de --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -0,0 +1,307 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include + +#include +#include "../common.h" + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +namespace transformer_engine { +namespace grouped_gemm { + +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; + +template struct TETypeToCKType; +template <> struct TETypeToCKType { using type = ck_tile::half_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; + +// Treat TE tensors as generalized 2D matrices by flattening: +// (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. 
+static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, + int64_t& d0, int64_t& d1) { + // Require at least a matrix (rank >= 2). Higher ranks are flattened. + if (t.shape().size() < 2) + return false; + d0 = static_cast(t.flat_first_dim()); + d1 = static_cast(t.flat_last_dim()); + return true; +} + +static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { + return t.data; // rowwise data view +} + +// Primus-Turbo-like FP16/BF16 tile configs +// Selection rule: +// if (N % 256 == 0) use 256x256x64 +// else if (N % 128 == 0) use 256x128x64 +// else use 256x128x64 with N padding enabled +struct TileCfg_256x256x64 { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + +struct TileCfg_256x128x64 : TileCfg_256x256x64 { + static constexpr ck_tile::index_t N_Tile = 128; +}; + +struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { + static constexpr bool kPadN = true; +}; + +// This class instantiates CK_Tile's grouped GEMM pipeline. +// See e.g. https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/03_gemm/universal_gemm_invoker.hpp for reference. 
+template +struct Runner{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; + + using UniversalTraits = ck_tile::PersistentTileGemmUniversalTraits< + TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, + TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>; + + static constexpr ck_tile::GemmPipelineScheduler Scheduler = + ck_tile::GemmPipelineScheduler::Intrawave; + + using Problem = ck_tile::UniversalGemmPipelineProblem< + AType, BType, AccType, GemmShape, UniversalTraits, Scheduler>; + + using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + using Epilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem< + AType, BType, ck_tile::tuple<>, AccType, + CType, ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, + Partitioner::MPerBlock, Partitioner::NPerBlock, + TileCfg::M_Warp, TileCfg::N_Warp, + TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, + Problem::TransposeC, MemOp>>; + + using Kernel = ck_tile::GroupedGemmKernel; +}; + +template +static bool run_grouped_impl(const NVTETensor* A_use, + const NVTETensor* B_use, + NVTETensor* D, + int group_num, + bool transA_use, + bool transB_use, + void* workspace, + size_t workspace_bytes, + hipStream_t stream) +{ + using Kernel = typename Runner::Kernel; + + const size_t needed = Kernel::GetWorkSpaceSize(group_num); + if (!workspace || workspace_bytes < needed) { + NVTE_ERROR("ck_tile_grouped_gemm: insufficient workspace. 
Needed bytes=", needed); + return false; + } + + thread_local std::vector> descs; + descs.clear(); + descs.reserve(group_num); + + for (int i = 0; i < group_num; ++i) { + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(A_use[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(B_use[i]); + transformer_engine::Tensor* D_te = + transformer_engine::convertNVTETensorCheck(D[i]); + + const auto& a = data_view(*A_te); + const auto& b = data_view(*B_te); + const auto& d = data_view(*D_te); + + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; + if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || + !get_flat_2d_dims(*B_te, Bd0, Bd1) || + !get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2 (2D or higher)."); + return false; + } + + const int64_t M = transA_use ? Ad1 : Ad0; + const int64_t K = transA_use ? Ad0 : Ad1; + const int64_t N = transB_use ? Bd0 : Bd1; + const int64_t Kb = transB_use ? 
Bd1 : Bd0; + + if (Kb != K) { + NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); + return false; + } + + if (Dd0 != M || Dd1 != N) { + NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); + return false; + } + + // Leading dimensions under the flattened-contiguous interpretation + const ck_tile::index_t stride_A = Ad1; + const ck_tile::index_t stride_B = Bd1; + const ck_tile::index_t stride_E = Dd1; + + descs.emplace_back( + a.dptr, + b.dptr, + std::array{}, + d.dptr, + 1, + M, + N, + K, + stride_A, + stride_B, + std::array{}, + stride_E); + } + + const dim3 grids = Kernel::GridSize(descs); + auto kargs = Kernel::MakeKargs(descs); + if (!Kernel::IsSupportedArgument(kargs)) { + NVTE_ERROR("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config."); + return false; + } + + HIP_CHECK_ERROR(hipMemcpyAsync(workspace, + kargs.data(), + kargs.size() * sizeof(typename decltype(kargs)::value_type), + hipMemcpyHostToDevice, + stream)); + + const ck_tile::stream_config s{stream}; + const dim3 blocks = Kernel::BlockSize(); + + ck_tile::launch_kernel( + s, + ck_tile::make_kernel<1>( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(workspace), + group_num)); + return true; +} + +} // namespace grouped_gemm +} // namespace transformer_engine + +bool ck_tile_grouped_gemm(const NVTETensor* A, + const NVTETensor* B, + NVTETensor* D, + int group_num, + bool transA, + bool transB, + NVTETensor* workspace, + bool accumulate, + hipStream_t stream) +{ + if (group_num <= 0) + return true; + + using namespace transformer_engine; + using namespace transformer_engine::grouped_gemm; + + // Workspace pointer + bytes + void* ws_ptr = nullptr; + size_t ws_bytes = 0; + if (workspace) { + auto* ws_te = convertNVTETensorCheck(*workspace); + ws_ptr = ws_te->data.dptr; + ws_bytes = ws_te->data.numel() * typeToSize(ws_te->data.dtype); + } + + // Normalize similar to upstream + // See 
https://github.com/NVIDIA/TransformerEngine/blob/59f6f3876767d07045152bfae07b5dd4c54e1725/transformer_engine/common/gemm/cutlass_grouped_gemm.cu#L54-L68 + // I.e., swap A and B, as well as transa and transb. + const NVTETensor* A_use = B; + const NVTETensor* B_use = A; + const bool transA_use = transB; + const bool transB_use = transA; + + const auto a_dtype = convertNVTETensorCheck(A_use[0])->dtype(); + + // Get N from D[0] (assume uniform N across groups) + int64_t ref_d0 = 0, ref_d1 = 0; + Tensor* D0_te = convertNVTETensorCheck(D[0]); + if (!get_flat_2d_dims(*D0_te, ref_d0, ref_d1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for D[0]"); + return false; + } + const ck_tile::index_t N = static_cast(ref_d1); + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, te_type, { + using T = typename TETypeToCKType::type; + + auto run_with_tilecfg = [&](auto tile_tag) -> bool { + using TileCfgSel = decltype(tile_tag); + + TRANSFORMER_ENGINE_SWITCH_CONDITION(transA_use, kTransA, { + using ALayout = std::conditional_t; + + TRANSFORMER_ENGINE_SWITCH_CONDITION(transB_use, kTransB, { + using BLayout = std::conditional_t; + + if (accumulate) { + return run_grouped_impl( + A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream); + } else { + return run_grouped_impl( + A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream); + } + }); + }); + }; + + // Select tile config like Primus-Turbo for FP16/BF16: + // N%256 -> 256x256x64 + // N%128 -> 256x128x64 + // else -> 256x128x64 padding + // NOTE: We assume N is uniform across groups. 
+ if ((N % 256) == 0) { + return run_with_tilecfg(TileCfg_256x256x64{}); + } else if ((N % 128) == 0) { + return run_with_tilecfg(TileCfg_256x128x64{}); + } else { + return run_with_tilecfg(TileCfg_256x128x64_padding{}); + } + }); +} diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.h b/transformer_engine/common/gemm/ck_grouped_gemm.h new file mode 100644 index 000000000..97b4cfd88 --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm.h @@ -0,0 +1,15 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +bool ck_tile_grouped_gemm(const NVTETensor* A, + const NVTETensor* B, + NVTETensor* D, + int group_num, + bool transA, + bool transB, + NVTETensor* workspace, + bool accumulate, + hipStream_t stream); diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 9c2ca9b4c..cbed586ca 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -24,8 +24,11 @@ #include "../util/logging.h" #include "../util/multi_stream.h" #include "common/util/cuda_runtime.h" +#include "common/util/system.h" #ifndef __HIP_PLATFORM_AMD__ #include "cutlass_grouped_gemm.cuh" +#else +#include "ck_grouped_gemm.h" #endif #ifndef __HIP_PLATFORM_AMD__ @@ -788,11 +791,12 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor NVTE_API_CALL(nvte_multi_tensor_gemm); #ifdef __HIP_PLATFORM_AMD__ - multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, - workspace, accumulate, use_split_accumulator, math_sm_count, stream); + if (num_gemms <= 0) + return; #else const int current_device = transformer_engine::cuda::current_device(); const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90); +#endif const bool use_cutlass = transformer_engine::getenv("NVTE_USE_CUTLASS_GROUPED_GEMM", false); const bool warn_fallback = transformer_engine::getenv("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", false); @@ -802,8 +806,13 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor workspace, accumulate, use_split_accumulator, math_sm_count, stream); }; +#ifdef __HIP_PLATFORM_AMD__ + // FIXME: The accumulate path is currently disabled due to instability on MI325. 
+ if (!use_cutlass || num_gemms == 1 || accumulate == true) { +#else // Currently only support cutlass group gemm on Hopper Arch if (!(is_hopper && use_cutlass)) { +#endif cublas_path(); return; } @@ -816,6 +825,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor return true; }; +#ifndef __HIP_PLATFORM_AMD__ auto all_groups_uniform_k128 = [&](const NVTETensor *p, bool trans) -> bool { int64_t ref_k = -1; for (size_t i = 0; i < num_gemms; i++) { @@ -832,17 +842,27 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor return true; }; +#endif auto is_supported_dtype = [&]() -> bool { auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]); auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]); +#ifdef __HIP_PLATFORM_AMD__ + auto A_dt = inputA->data.dtype; + auto B_dt = inputB->data.dtype; + auto D_dt = OutputD->data.dtype; + return (A_dt == B_dt) && (A_dt == D_dt) && + (A_dt == transformer_engine::DType::kFloat16 || + A_dt == transformer_engine::DType::kBFloat16); +#else auto A_type = get_cuda_dtype(inputA->data.dtype); auto B_type = get_cuda_dtype(inputB->data.dtype); auto D_type = get_cuda_dtype(OutputD->data.dtype); return (A_type == B_type) && (A_type == D_type) && ((A_type == CUDA_R_16BF) || (A_type == CUDA_R_16F)); +#endif }; // CUTLASS Grouped GEMM fast path (SM90/TMA) @@ -855,14 +875,23 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor // // Otherwise, fall back to cuBLAS. 
if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() && +#ifdef __HIP_PLATFORM_AMD__ + true) { + if (!ck_tile_grouped_gemm(A, B, D, num_gemms, transa, transb, workspace, accumulate, stream)) { + if (warn_fallback) { + NVTE_WARN("Fallback to cuBLAS grouped GEMM."); + } + cublas_path(); + } +#else all_groups_uniform_k128(B, transb)) { cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate, current_device, math_sm_count, stream); +#endif } else { if (warn_fallback) { NVTE_WARN("Fallback to cuBLAS grouped GEMM."); } cublas_path(); } -#endif // __HIP_PLATFORM_AMD__ }