From e8c5cfc886beff082974a26576648bfc73b82708 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Fri, 2 Jan 2026 12:23:46 -0800 Subject: [PATCH 01/22] first draft Signed-off-by: Zhongbo Zhu --- transformer_engine/common/CMakeLists.txt | 1 + ...cast_col_hadamard_transform_cast_fusion.cu | 1333 +++++++++++++++++ .../transformer_engine/hadamard_transform.h | 16 + 3 files changed, 1350 insertions(+) create mode 100644 transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 4579c51e9f..3e6eee6d86 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -177,6 +177,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources hadamard_transform/graph_safe_group_hadamard_transform.cu hadamard_transform/hadamard_transform.cu hadamard_transform/hadamard_transform_cast_fusion.cu + hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu hadamard_transform/group_hadamard_transform_cast_fusion.cu hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu diff --git a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu new file mode 100644 index 0000000000..14aa4e0cdb --- /dev/null +++ b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu @@ -0,0 +1,1333 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + + #include + #include + #include + #include + #include + #include + #include + + #include + #include + #include + #include + + #include "common/common.h" + #include "common/util/cuda_runtime.h" + #include "common/util/curanddx.hpp" + #include "common/util/ptx.cuh" + #include "common/utils.cuh" + #include "customized_pipeline.cuh" + #include "cutlass/arch/barrier.h" + #include "cutlass/arch/reg_reconfig.h" + #include "cutlass/cluster_launch.hpp" + #include "cutlass/cutlass.h" + #include "cutlass/detail/sm100_blockscaled_layout.hpp" + #include "cutlass/fast_math.h" + #include "cutlass/float8.h" + #include "cutlass/float_subbyte.h" + #include "cutlass/gemm/collective/builders/sm100_common.inl" + #include "cutlass/numeric_conversion.h" + #include "cutlass/numeric_types.h" + #include "cutlass/pipeline/pipeline.hpp" + #include "cutlass/platform/platform.h" + #include "cutlass/util/GPU_Clock.hpp" + #include "cutlass/util/command_line.h" + #include "cutlass/util/print_error.hpp" + +// clang-format off + +namespace transformer_engine { +namespace detail { +namespace { + +using namespace cute; + +// Ensure Tensor refers to cute::Tensor, not transformer_engine::Tensor +using cute::Tensor; + +struct CLCResponse { uint32_t data[4] = {0}; }; /* One zero-initialized slot of 4 x uint32_t (16 bytes); the tile scheduler in this file reads a slot back with ld.shared.b128 when decoding a cluster-launch-control response. */ + + +CUTLASS_DEVICE /* Converts the eight floats input[0..7] to e2m1 (FP4) using stochastic rounding: two cvt.rs.satfinite.e2m1x4.f32 instructions, each taking four floats plus one 32-bit word of random bits (rbits[0], rbits[1]) and writing one 16-bit register of the result. NOTE(review): the source operands are passed in reversed order within each group of four; presumably this matches the instruction's packing order - confirm against the PTX ISA. */ +cutlass::Array StochasticNumericConverterBase( + cutlass::Array const &input, cutlass::Array const &rbits) { + using result_type = cutlass::Array; + result_type output; + auto output_ptr = reinterpret_cast(&output); + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; /* compile-time switch: the PTX below is only emitted when this is true, otherwise the else-branch raises a device error */ + if constexpr (has_rs) { + asm volatile( /* %0/%1 are the two 16-bit outputs, %2..%9 the eight float inputs, %10/%11 the two random words */ + "{\n" + "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" + "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" + "}" + : "=h"(output_ptr[0]), "=h"(output_ptr[1]) + : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), "f"(input[4]), "f"(input[5]), + "f"(input[6]), "f"(input[7]), "r"(rbits[0]), "r"(rbits[1])); 
+ } else { /* build does not provide the stochastic-rounding cvt instruction */ + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return output; +} + +CUTLASS_DEVICE /* Wider stochastic-rounding converter: reinterprets input, rbits and the output as two consecutive sub-arrays each and converts every half with StochasticNumericConverterBase. NOTE(review): presumably 16 floats and 4 random words in total, given that the base routine consumes 8 floats and 2 words - confirm against the template arguments. */ +cutlass::Array +StochasticNumericConverter(cutlass::Array const &input, cutlass::Array const &rbits) { + using result_type = cutlass::Array; + result_type output; + cutlass::Array *result_ptr = reinterpret_cast *>(&output); + cutlass::Array const *source_ptr = reinterpret_cast const *>(&input); + cutlass::Array const *rbits_ptr = reinterpret_cast const *>(&rbits); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; i++) { /* one base conversion per half */ + result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]); + } + return output; +} + +template < + class ElementA, + class ElementB, + class ASmemLayout, + class BSmemLayout, + class ClusterShape, + int AccumulatorPipelineStageCount_, + int EpilogueUnrollFactor_, + int SchedulerPipelineStageCount_> +struct SharedStorage { /* Layout that the kernel in this file overlays on its dynamic shared memory: the smem_A / smem_B operand buffers (TensorStorage), the storage of the accumulator, mainloop, CLC and CLC-throttle pipelines, one TMA barrier, SchedulerPipelineStageCount_ CLCResponse slots, and the TMEM base pointer. */ + static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; + static int constexpr EpilogueUnrollFactor = EpilogueUnrollFactor_; + using AtomThrShapeMNK = cute::Shape<_1, _1, _1>; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{}); /* stage count is read from mode 3 of the A shared-memory layout */ + using MainloopPipeline = cutlass::detail::CustomizedPipelineTmaUmmaAsync< + MainloopPipelineStageCount, + Shape<_1,_1,_1>, + AtomThrShapeMNK>; + using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage; + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using CLCThrottlePipeline = cutlass::PipelineAsync; + using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; + + struct TensorStorage : cute::aligned_struct<128, _1> { /* operand buffers: smem_A and smem_B */ + // cute::array_aligned> smem_A; + cute::array_aligned> smem_A; 
cute::array_aligned> smem_B; + } tensors; + + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) cute::uint64_t tma_barrier[1]; + alignas(16) CLCPipelineStorage clc; + alignas(16) CLCThrottlePipelineStorage clc_throttle; + alignas(16) CLCResponse clc_response[SchedulerPipelineStageCount_]; + uint32_t tmem_base_ptr; +}; + +template +__launch_bounds__(512, 1) +__global__ static void row_col_rht_gemm_device( + MShape M, + NShape N, + KShape K, + ClusterShape cluster_shape, + ClusterTileShape cluster_tile, + TA const* A, + AStride dA, + ASmemLayout sAlayout, + CUTE_GRID_CONSTANT TmaLoadA const tma_load_a, + TB const* B, + BStride dB, + BSmemLayout sBlayout, + CUTE_GRID_CONSTANT TmaLoadB const tma_load_b, + TD* D, + DStride dD, + DSmemLayout, + TSFD* SFD, + TSFDLayout sfd_layout, + TQA* QA, + QAStride dQA, + TSFA* SFA, + TSFALayout sfa_layout, + TiledMMA mma, + float const* a_global_amax, + float const* c_global_amax, + const size_t* rng_state) { + using namespace cute; + using X = Underscore; + // static constexpr bool kApplyStochasticRounding = true; + using ElementAccumulator = float; + static int constexpr K_PIPE_MAX = size<3>(ASmemLayout{}); + using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>; + static uint32_t constexpr kTmaTransactionBytes = cutlass::bits_to_bytes( + size(AtomThrShapeMNK{}) * cosize(take<0,3>(ASmemLayout{})) * cute::sizeof_bits_v); + static constexpr bool kEnableStochasticRounding = kEnableStochasticRounding_; + static constexpr bool kEnableRHTColQuant = kEnableRHTColQuant_; + static constexpr bool kEnableRowQuant = kEnableRowQuant_; + static constexpr bool kEnableFastMath = kEnableFastMath_; + static int constexpr RhtTensorSize = 16; + static int constexpr kTmaRhtTensorTransactionBytes = cutlass::bits_to_bytes( + RhtTensorSize * RhtTensorSize * cute::sizeof_bits_v); + static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; + 
static int constexpr SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + + static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{}); + using MainloopPipeline = cutlass::detail::CustomizedPipelineTmaUmmaAsync< + MainloopPipelineStageCount, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineState = typename CLCPipeline::PipelineState; + using CLCThrottlePipeline = cutlass::PipelineAsync; + using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + + static_assert(ClusterShape{} == Shape<_1,_1,_1>{}, "ClusterShape must be Shape<_1,_1,_1>"); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + static int constexpr VectorSize = RhtTensorSize; + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + auto cluster_size = size<0>(cluster_shape); + auto mainloop_tiler = Shape<_128,_16,_128>{}; + auto epilogue_tiler = Shape<_128,_128,_128>{}; + + static int constexpr EpilogueUnrollFactor = size<2>(epilogue_tiler) / size<2>(cluster_tile); + + // Get the appropriate blocks for this Cluster + dim3 cluster_coord_in_grid = cluster_id_in_grid(); + + // Total number of k-tiles + int const K_TILE_MAX = ceil_div(min(N, K), size<2>(epilogue_tiler)); + + struct TileScheduler { + struct WorkTileInfo { + uint32_t m_idx = 0; + uint32_t n_idx = 0; + uint32_t l_idx = 0; + bool is_valid_tile = false; + }; + uint32_t tiles_in_m = 0; + uint32_t tiles_in_n = 0; + + int k_tile_max = 0; + + int wave_cnt = 0; + WorkTileInfo work_tile_info; + WorkTileInfo next_work_tile_info; + CLCResponse* clc_response_ptr_; + CUTLASS_DEVICE TileScheduler(uint32_t tiles_m, uint32_t tiles_n, int kmax, CLCResponse* clc_response_ptr) + : tiles_in_m(tiles_m), + tiles_in_n(tiles_n), + + k_tile_max(kmax), + work_tile_info({blockIdx.x, blockIdx.y, blockIdx.z, 
blockIdx.x( + &clc_response_ptr[state.index()])); + asm volatile( + "{\n\t" + "clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.multicast::cluster::all.b128 [%0], [%1];\n\t" + "}\n" + : + : "r"(result_addr), "r"(mbarrier_addr)); + #else + CUTLASS_NOT_IMPLEMENTED(); + #endif + } + CUTLASS_DEVICE + static WorkTileInfo + work_tile_info_from_clc_response(uint32_t result_addr) { + WorkTileInfo work_tile_info; + uint32_t valid = 0; + #if defined(CUTLASS_ARCH_CLC_ENABLED) + asm volatile( + "{\n" + ".reg .pred p1;\n\t" + ".reg .b128 clc_result;\n\t" + "ld.shared.b128 clc_result, [%4];\n\t" + "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p1, clc_result;\n\t" + "selp.u32 %3, 1, 0, p1;\n\t" + "@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128 {%0, %1, %2, _}, clc_result;\n\t" + "}\n" + : "=r"(work_tile_info.m_idx), "=r"(work_tile_info.n_idx), "=r"(work_tile_info.l_idx), "=r"(valid) + : "r"(result_addr) + : "memory" + ); + + cutlass::arch::fence_view_async_shared(); + #else + CUTLASS_NOT_IMPLEMENTED(); + #endif + work_tile_info.is_valid_tile = (valid == 1); + return work_tile_info; + } + }; + + + + // Allocate SMEMork + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile)))); + uint32_t tiles_in_n = uint32_t(size(ceil_div(N, size<2>(epilogue_tiler)))); + TileScheduler scheduler(tiles_in_m, tiles_in_n, K_TILE_MAX, shared_storage.clc_response); + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + auto acc_shape_mma = make_shape(take<0,2>(mainloop_tiler), _1{}, _1{}); + auto acc_shape_epilogue = make_shape(take<0,2>(epilogue_tiler), _1{}, _1{}); + + auto acc_mainloop_pipelined_shape = append(acc_shape_mma, Int{}); + auto bulk_tmem_mma = TiledMMA::make_fragment_C(acc_mainloop_pipelined_shape); + + static int constexpr NumEpilogueColQuantThreadCount = 
kEnableRHTColQuant ? 128 : 0; + static int constexpr NumEpilogueRowQuantThreadCount = kEnableRowQuant ? 256 : 0; + static int constexpr NumMmaThreadCount = kEnableRHTColQuant? 32: 0; + static int constexpr NumMmaIssueThreadCount = kEnableRHTColQuant? 1: 0; + static int constexpr NumSchedThreads = 32; + static int constexpr NumMainloopLoadThreads = 32; + static int constexpr NumEpilogueThreads = NumEpilogueColQuantThreadCount + NumEpilogueRowQuantThreadCount; + + TmemAllocator tmem_allocator{}; + cutlass::arch::NamedBarrier tmem_allocation_result_barrier( + NumMmaThreadCount + NumEpilogueColQuantThreadCount, + cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + + // warp assignment + bool is_mma_warp = (warp_idx == 0); + bool is_dma_warp = (warp_idx == 1); + bool is_sched_warp = (warp_idx == 2); + bool is_epilogue_col_quant_warp = (warp_idx >= 4 && warp_idx <= 7); + bool is_epilogue_row_quant_warp = (warp_idx >= 8 && warp_idx <= 15); + + if (is_epilogue_col_quant_warp && elect_one_sync()) { + cute::prefetch(raw_pointer_cast(c_global_amax)); + } + if (is_epilogue_row_quant_warp && elect_one_sync()) { + cute::prefetch(raw_pointer_cast(a_global_amax)); + } + + typename MainloopPipeline::Params mainloop_pipeline_params; + if (is_dma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (is_mma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp; + mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + mainloop_pipeline_params.num_consumers = NumEpilogueRowQuantThreadCount + NumMmaIssueThreadCount; + MainloopPipeline mainloop_pipeline( + shared_storage.mainloop, + mainloop_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask 
calculation + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (is_mma_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (is_epilogue_col_quant_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueColQuantThreadCount; + accumulator_pipeline_params.initializing_warp = 1; + using IsInitAccumulatorPipeline = cute::conditional_t; + AccumulatorPipeline accumulator_pipeline( + shared_storage.accumulator, + accumulator_pipeline_params, + cluster_shape, + IsInitAccumulatorPipeline{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + // CLC pipeline + typename CLCPipeline::Params clc_pipeline_params; + if (is_sched_warp) { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; + } + else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.producer_arv_count = 1; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopLoadThreads + NumEpilogueThreads + NumMmaThreadCount); + clc_pipeline_params.transaction_bytes = sizeof(CLCResponse); + clc_pipeline_params.initializing_warp = 3; + CLCPipeline clc_pipeline(shared_storage.clc, clc_pipeline_params, cluster_shape); 
+ CLCPipelineState clc_pipeline_consumer_state; + CLCPipelineState clc_pipeline_producer_state = cutlass::make_producer_start_state(); + + // CLC throttle pipeline + typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; + if (is_dma_warp) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; + } + if (is_sched_warp) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; + } + clc_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + clc_throttle_pipeline_params.dst_blockid = 0; + clc_throttle_pipeline_params.initializing_warp = 4; + + CLCThrottlePipeline clc_throttle_pipeline(shared_storage.clc_throttle, clc_throttle_pipeline_params); + CLCThrottlePipelineState clc_pipe_throttle_consumer_state; + CLCThrottlePipelineState clc_pipe_throttle_producer_state = cutlass::make_producer_start_state(); + + if (warp_idx == 2 && elect_one_sync()) { + cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); + } + __syncthreads(); + + if (is_dma_warp) { + cutlass::arch::warpgroup_reg_dealloc<32>(); + Tensor mA = tma_load_a.get_tma_tensor(make_shape(M,N)); + Tensor mB = tma_load_b.get_tma_tensor(make_shape(RhtTensorSize, RhtTensorSize)); + + Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_,_, _), Step<_1, X,_1>{}); + Tensor gB_nk = local_tile(mB, cluster_tile, make_coord(_,_, _), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + Tensor tCgA = thr_mma.partition_A(gA_mk); // (MMA,MMA_M,MMA_K,k) + Tensor tCgB = thr_mma.partition_B(gB_nk); 
// (MMA,MMA_N,MMA_K,k) + + Layout cta_layout_mnk = make_layout(cluster_shape); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + auto [tAgA, tAsA] = tma_partition( + tma_load_a, + get<2>(cta_coord_vmnk), + make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(tCsA), + group_modes<0,3>(tCgA)); + + auto [tBgB, tBsB] = tma_partition( + tma_load_b, + get<1>(cta_coord_vmnk), + make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(tCsB), + group_modes<0,3>(tCgB)); + + uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + if constexpr (kEnableRHTColQuant) { + if (elect_one_sync()) { + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], kTmaRhtTensorTransactionBytes); + copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_,0,0), tBsB(_,0)); + } + } + + do { + bool is_first_wave = scheduler.is_first_wave(); + uint32_t skip_wait = is_first_wave; + auto tAgA_mk = tAgA(_,scheduler.tile_m(),_); + int k_tile = 0; + // Throttle CLC producer + clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); + clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); + ++clc_pipe_throttle_producer_state; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n()) { + + int k_tile_idx_n = scheduler.tile_n_base() + k_tile; + ++k_tile; + skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount); + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier( + mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + 
++mainloop_pipe_producer_state; + if (cute::elect_one_sync()) { + copy( + tma_load_a.with(*tma_barrier, tma_mcast_mask_a), + tAgA_mk(_,k_tile_idx_n), + tAsA(_,write_stage)); + } + } + scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); + ++clc_pipeline_consumer_state; + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + else if (is_mma_warp) { + cutlass::arch::warpgroup_reg_dealloc<32>(); + if constexpr (kEnableRHTColQuant) { + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + // Allocate "fragments" -- these are actually umma smem descriptors + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) + + mma.accumulate_ = UMMA::ScaleOut::Zero; + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_mma.data() = tmem_base_ptr; + cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); + do { + uint32_t skip_wait = K_TILE_MAX <= 0; + + auto barrier_token = mainloop_pipeline.consumer_try_wait( + mainloop_pipe_consumer_state, + skip_wait); + scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); + ++clc_pipeline_consumer_state; + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); ) { + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + int read_stage = 
mainloop_pipe_consumer_state.index(); + auto tCrA_mk = tCrA(_,_,_,read_stage); + auto tCrB_nk = tCrB(_,_,0,0); + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) / EpilogueUnrollFactor; ++k_block) + { + int accumulator_k_block = accumulator_pipe_producer_state.index() * EpilogueUnrollFactor; + int tCrA_k_block = k_block * EpilogueUnrollFactor; + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < EpilogueUnrollFactor; i++) { + auto accumulators = bulk_tmem_mma(_,_,_,accumulator_k_block + i); + gemm(mma, tCrA_mk(_,_,tCrA_k_block + i), tCrB_nk, accumulators); + } + + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + ++accumulator_pipe_producer_state; + } + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + ++mainloop_pipe_consumer_state; + ++k_tile; + skip_wait = k_tile >= K_TILE_MAX; + mainloop_pipeline.umma_consumer_release(curr_mainloop_pipe_consumer_state); + barrier_token = mainloop_pipeline.consumer_try_wait( + mainloop_pipe_consumer_state, + skip_wait); + } + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + tmem_allocator.release_allocation_lock(); + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + } + else if(is_sched_warp) { + cutlass::arch::warpgroup_reg_dealloc<32>(); + do { + clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); + clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); + ++clc_pipe_throttle_consumer_state; + clc_pipeline_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipeline_producer_state); + scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); + ++clc_pipeline_consumer_state; + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } + else if (is_epilogue_col_quant_warp) { + 
cutlass::arch::warpgroup_reg_alloc<192>(); + if constexpr (kEnableRHTColQuant) { + using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; + + float const c_global_amax_val = *c_global_amax; + auto acc_epilogue_pipelined_shape = append(acc_shape_epilogue, Int{}); + auto bulk_tmem_epilogue_layout = make_layout( + acc_epilogue_pipelined_shape, + make_stride( + stride<0>(bulk_tmem_mma), + Int<0>{}, + Int<0>{}, + size<1>(epilogue_tiler))); + auto bulk_tmem_epilogue = make_tensor(make_tmem_ptr(), bulk_tmem_epilogue_layout); + + // leveraging 256-bit writes to global memory + static int constexpr FragmentSize = 256 / sizeof_bits_v; + + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_epilogue.data() = tmem_base_ptr; + int global_thread_idx = threadIdx.x; + int local_thread_idx = global_thread_idx % cutlass::NumThreadsPerWarpGroup; + + size_t rng_seed = 0; + size_t rng_offset = 0; + if constexpr (kEnableStochasticRounding) { + rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; + rng_offset = rng_state != nullptr ? 
__ldg(rng_state + 1) : 0; + } + + Tensor mD = make_tensor( + cute::subbyte_iterator(D), + make_shape(M,N), + dD); // (M,N) + Tensor gD_mn = local_tile( + mD, + epilogue_tiler, + make_coord(_,_, _), + Step<_1,_1, X>{}); // (BLK_M,BLK_N) + Tensor pD = make_identity_tensor(mD.shape()); + Tensor pD_mn = local_tile( + pD, + epilogue_tiler, + make_coord(_,_, _), + Step<_1,_1, X>{}); // (BLK_M,BLK_N) + Tensor mSFD = make_tensor(make_gmem_ptr(SFD), sfd_layout); + Tensor gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) + Tensor pSFD = make_identity_tensor(mSFD.shape()); + Tensor pSFD_mn = local_tile(pSFD, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + Tensor gD_mn_view = tiled_divide(gD_mn, take<0,2>(epilogue_tiler)); + Tensor pD_mn_view = tiled_divide(pD_mn, take<0,2>(epilogue_tiler)); + auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_,_,_,_0{})); + auto tiled_r2g = make_tiled_copy_D( + Copy_Atom{}, + tiled_t2r); + auto thr_t2r = tiled_t2r.get_slice(local_thread_idx); + auto thr_r2g = tiled_r2g.get_slice(local_thread_idx); + + // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} + static constexpr float fp4_max = 6.0f; + static constexpr float fp8_max = 448.0f; + float const fp4_max_inv = 1.0f / fp4_max; + float const global_encode_scale = c_global_amax_val > 0.0f + ? 
cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / c_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + + float const global_decode_scale = 1.0f / global_encode_scale; + float const global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + auto sfc_converter = cutlass::NumericConverter{}; + + do { + scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); + ++clc_pipeline_consumer_state; + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); ++k_tile) { + Tensor tDgD_mn = gD_mn_view(_,_,_,scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + Tensor tDgSFD_mn = gSFD_mn(_,_,scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + Tensor tDpD_mn = pD_mn_view(_,_,_,scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + Tensor tDpSFD_mn = pSFD_mn(_,_,scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + + accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); + + auto Acc = bulk_tmem_epilogue(_,_,_,accumulator_pipe_consumer_state.index()); + Tensor tDtAcc = thr_t2r.partition_S(Acc); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDgD = thr_t2r.partition_D(tDgD_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDpD = thr_t2r.partition_D(tDpD_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tTR_rAcc = make_tensor(shape(tDgD)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDrD = make_tensor(shape(tDgD)); + Tensor tTR_rAcc_frag = recast>(coalesce(tTR_rAcc)); + Tensor tDrD_frag = recast>(coalesce(tDrD)); + + Tensor src = thr_r2g.retile_S(tDrD); + Tensor dst = thr_r2g.retile_D(tDgD); + Tensor pSrc = thr_r2g.retile_D(tDpD); + + Tensor tDgSFD_view = make_tensor( + tDgSFD_mn.data(), + make_layout( + make_shape(shape(tDgSFD_mn), Int<1>{}, Int<1>{}), + make_stride(stride(tDgSFD_mn), Int<0>{}, Int<0>{}))); + Tensor tDpSFD_view = make_tensor( + tDpSFD_mn.data(), + make_layout( + make_shape(shape(tDpSFD_mn), Int<1>{}, Int<1>{}), + 
make_stride(stride(tDpSFD_mn), Int<0>{}, Int<0>{}))); + Tensor tDgSFD = filter(thr_t2r.partition_D(tDgSFD_view)); + Tensor tDrSFD = make_tensor(shape(tDgSFD)); + Tensor tDpSFD = filter(thr_t2r.partition_D(tDpSFD_view)); + static int constexpr NumVecs = size(tDgD) / VectorSize; + Tensor tD_rRowSFD_frg = recast>(tDrSFD); + + cutlass::maximum_absolute_value_reduction, true> amax_reduction; + cutlass::Array vec_maxs; + cutlass::Array pvscales; + // TMEM_LOAD + copy(tiled_t2r, tDtAcc, tTR_rAcc); + cutlass::arch::fence_view_async_tmem_load(); + accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); + ++accumulator_pipe_consumer_state; + + auto compute_frgs = reinterpret_cast *>(tTR_rAcc_frag.data()); + auto output_frgs = reinterpret_cast *>(tDrD_frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); + } + + pvscales = cutlass::multiplies>{}(vec_maxs, global_encode_scale_multiplier); + auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); + + tD_rRowSFD_frg(_0{}) = pvscales_cvted; + auto qpvscale_ups = cutlass::NumericArrayConverter{}(tD_rRowSFD_frg(_0{})); + auto qpvscale_scaled = cutlass::multiplies>{}( + qpvscale_ups, + global_decode_scale); + + cutlass::Array acc_scales; + if constexpr (kEnableFastMath) { + // fast math: use reciprocal approximate to replace div + acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + } else { + // regular path for slower math, use divide to replace div + acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); + } + + uint4 random_uint4 = uint4{0, 0, 0, 0}; + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + // "Prefetch" a stochastic rounding state for the first tile + if constexpr (kEnableStochasticRounding) { + const size_t rng_sequence = global_thread_idx + k_tile * 512 + scheduler.get_linear_tile_idx() * K_TILE_MAX * 512; + rng.init(rng_seed, rng_sequence, rng_offset); + 
} + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + auto acc_scale = cutlass::minimum_with_nan_propagation{}( + acc_scales[v], + cutlass::platform::numeric_limits::max()); + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter(cutlass::multiplies>{}(compute_frgs[v], acc_scale), *reinterpret_cast*>(&random_uint4)); + } else { + output_frgs[v] = cutlass::NumericArrayConverter{}( + cutlass::multiplies>{}( + compute_frgs[v], + acc_scale)); + } + + } + + copy_if(tiled_r2g, + [&](auto coord){ + Tensor pSrc_view = group_modes<1,rank(pSrc)>(pSrc); + return elem_less(pSrc_view(_0{},coord), shape(mD)); + }, + src, dst); + // 32bit vectorization copy 4 e4m3 SFD for per 64/(16,4):(0, 1) element + constexpr int vec_len = 32 / sizeof_bits_v; + Tensor tDrSFD_v = recast>(tDrSFD); + Tensor tDgSFD_v = recast>(tDgSFD); + copy_if( + [&](auto coord){ + Tensor tDpSFD_view = group_modes<1,rank(tDpSFD)>(tDpSFD); + return elem_less(tDpSFD_view(_0{}, coord * vec_len), shape(mSFD)); + }, + tDrSFD_v, tDgSFD_v); + } + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } + } + else if (is_epilogue_row_quant_warp) { + cutlass::arch::warpgroup_reg_alloc<136>(); + if constexpr (kEnableRowQuant) { + using S2RVectorType = uint128_t; + float const a_global_amax_val = *a_global_amax; + int global_thread_idx = threadIdx.x; + int local_thread_idx = global_thread_idx % 256; + size_t rng_seed = 0; + size_t rng_offset = 0; + if constexpr (kEnableStochasticRounding) { + rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; + rng_offset = rng_state != nullptr ? 
__ldg(rng_state + 1) : 0; + } + Tensor mQA = make_tensor(cute::subbyte_iterator(QA), make_layout(make_shape(M, N), dQA)); + Tensor gQA_mn = local_tile(mQA, epilogue_tiler, make_coord(_,_, _), Step<_1,X,_1>{}); + Tensor pQA = make_identity_tensor(mQA.shape()); + Tensor pQA_mn = local_tile(pQA, epilogue_tiler, make_coord(_,_, _), Step<_1,X,_1>{}); + + Tensor mSFA = make_tensor(make_gmem_ptr(SFA), sfa_layout); + Tensor gSFA_mn = local_tile(mSFA, epilogue_tiler, make_coord(_,_, _), Step<_1,X,_1>{}); // (BLK_M,BLK_N) + Tensor pSFA = make_identity_tensor(mSFA.shape()); + Tensor pSFA_mn = local_tile(pSFA, epilogue_tiler, make_coord(_,_, _), Step<_1,X,_1>{}); + Tensor sA = as_position_independent_swizzle_tensor( + group_modes<0,2>(coalesce(make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout)))); // (BLOCK_M, BLOCK_M,PIPE) + using S2RWarpLayout = Layout>; + using WarpGroupLayout = Layout>; + using S2RThreadLayout = decltype(blocked_product(S2RWarpLayout{}, WarpGroupLayout{})); + using S2RValLayout = Layout, _1>>; + using S2RAtomA = Copy_Atom; + using R2GAtomQA = Copy_Atom; + + auto tiled_s2r = make_tiled_copy(S2RAtomA{}, S2RThreadLayout{}, S2RValLayout{}); + auto tiled_r2g_QA = make_tiled_copy_D(R2GAtomQA{}, tiled_s2r); + + auto thr_s2r = tiled_s2r.get_slice(local_thread_idx); + auto thr_r2g_QA = tiled_r2g_QA.get_slice(local_thread_idx); + + Tensor tQAsA = thr_s2r.partition_S(sA); // (Copy, Copy_M, Copy_N, PIPE) + + Tensor tQArA = make_tensor_like(make_layout(tQAsA(_, _, _, _0{}).shape())); // (Copy, Copy_M, Copy_N) + // Tensor tQArA_PI = thr_s2r.partition_S(sA_PI); + Tensor tQAgQA = thr_r2g_QA.partition_D(gQA_mn); + Tensor tQArQA = make_tensor_like(tQAgQA(_, _, _, _0{}, _0{})); + Tensor tQApQA = thr_r2g_QA.partition_D(pQA_mn); + + Tensor tQAgSFA = thr_s2r.partition_D(gSFA_mn); + Tensor tQArSFA = make_tensor_like(tQAgSFA(_, _, _, _0{}, _0{})); + Tensor tQApSFA = thr_s2r.partition_D(pSFA_mn); + + // Aligning with TensorEngine's recipe to generate scale 
factors // {$nv-internal-release} + static constexpr float fp4_max = 6.0f; + static constexpr float fp8_max = 448.0f; + float const fp4_max_inv = 1.0f / fp4_max; + float const global_encode_scale = a_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / a_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + + float const global_decode_scale = 1.0f / global_encode_scale; + float const global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + auto sfa_converter = cutlass::NumericConverter{}; + do { + uint32_t skip_wait = K_TILE_MAX <= 0; + + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); ) { + auto tQAgSFA_mn = tQAgSFA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto tQAgQA_mn = tQAgQA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto tQApSFA_mn = tQApSFA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto tQApQA_mn = tQApQA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto barrier_token = mainloop_pipeline.consumer_try_wait( + mainloop_pipe_consumer_state); + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + copy(tiled_s2r, tQAsA(_, _, _, mainloop_pipe_consumer_state.index()), tQArA); + cutlass::arch::fence_view_async_shared(); + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + ++mainloop_pipe_consumer_state; + ++k_tile; + skip_wait = k_tile >= K_TILE_MAX; + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + // static int constexpr NumVecs = size(tQArA) / VectorSize; + cutlass::maximum_absolute_value_reduction, true> amax_reduction; + auto compute_frgs = reinterpret_cast *>(tQArA.data()); + auto output_frgs = reinterpret_cast *>(raw_pointer_cast(tQArQA.data())); + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + if constexpr 
(kEnableStochasticRounding) { + const size_t rng_sequence = global_thread_idx + k_tile * 512 + scheduler.get_linear_tile_idx() * K_TILE_MAX * 512; + rng.init(rng_seed, rng_sequence, rng_offset); + } + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < size(tQArA)/VectorSize; v++) { + auto compute_frgs_up = cutlass::NumericArrayConverter{}(compute_frgs[v]); + auto amax = amax_reduction(ElementAccumulator(0), compute_frgs_up); + auto pvscales= cutlass::multiplies{}(amax, global_encode_scale_multiplier); + filter(tQArSFA)(v) = sfa_converter(pvscales); + auto qpvscale_ups = cutlass::NumericConverter{}(filter(tQArSFA)(v)); + auto qpvscale_scaled = cutlass::multiplies{}(qpvscale_ups, global_decode_scale); + ElementAccumulator acc_scales; + if constexpr (kEnableFastMath) { + // fast math: use reciprocal approximate to replace div + acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + } else { + // regular path for slower math, use divide to replace div + acc_scales = cutlass::divides{}(1.0, qpvscale_scaled); + } + auto acc_scale = cutlass::minimum_with_nan_propagation{}( + acc_scales, + cutlass::platform::numeric_limits::max()); + uint4 random_uint4 = uint4{0, 0, 0, 0}; + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter(cutlass::multiplies>{}(compute_frgs_up, acc_scale), *reinterpret_cast*>(&random_uint4)); + } else { + output_frgs[v] = cutlass::NumericArrayConverter{}( + cutlass::multiplies>{}( + compute_frgs_up, + acc_scale)); + } + } + + copy_if(tiled_r2g_QA, + [&](auto coord){ + Tensor tQApQA_view = group_modes<1,rank(tQApQA_mn)>(tQApQA_mn); + return elem_less(tQApQA_view(_0{}, coord), shape(mQA)); + }, + tQArQA, tQAgQA_mn); + // 32bit vectorization copy 4 e4m3 SFA for per 64/(16,4):(0, 1) element + constexpr int vec_len = 32 / sizeof_bits_v; + Tensor tQArSFA_v = recast>(filter(tQArSFA)); + Tensor tQAgSFA_v = recast>(filter(tQAgSFA_mn)); + copy_if( + [&](auto coord){ + Tensor 
tQApSFA_view = filter(tQApSFA_mn); + return elem_less(tQApSFA_view(_0{}, coord * vec_len), shape(mSFA)); + }, + tQArSFA_v, tQAgSFA_v); + } + scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); + ++clc_pipeline_consumer_state; + scheduler.update_work_tile_info(); + }while (scheduler.is_valid()); + } + } else { + cutlass::arch::warpgroup_reg_dealloc<32>(); + } +} + + +// this function computes RHT-GEMM for +// m = hidden_size, n = sequence_length +// A: m x n: col-major +// B: 16 x 16: row-major +// D: m x n: row-major +// SFD: m x (n/16): row-major +// QA: m x n: col-major +// SFA: m/16 x n: col-major +template +void row_col_rht_gemm_ntt_w_sfc( + int sequence_length, + int hidden_size, + TA const* A, + TB const* B, + TD* D, + TSFD* SFD, + TQA* QA, + TSFA* SFA, + float const* a_global_amax, + float const* d_global_amax, + const size_t* rng_state, + uint32_t sm_count, + cudaStream_t stream, + int k_tile_size = 1024) { + using namespace cute; + static int constexpr SFVecSize = 16; + static int constexpr RhtTensorSize = 16; + + static_assert(RhtTensorSize == 16, "RhtTensorSize must be 16"); + using LinearSFALayout = decltype(make_layout(make_shape(make_shape(Int{}, 0), 0), make_stride(make_stride(_0{}, _1{}), 0))); + using LinearSFCLayout = decltype(make_layout(make_shape(0, make_shape(Int{}, 0)), make_stride(0, make_stride(_0{}, _1{})))); + + using SwizzledSFALayoutAtom = cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + using SwizzledSFDLayoutAtom = cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + using SwizzledSFALayout = decltype(tile_to_shape(SwizzledSFALayoutAtom{}, make_shape(hidden_size,sequence_length), Step<_1,_2>{})); + using SwizzledSFDLayout = decltype(tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(hidden_size,sequence_length), Step<_2,_1>{})); + + using SFALayout = cute::conditional_t; + using SFCLayout = cute::conditional_t; + SFALayout sfa_layout; + SFCLayout sfd_layout; + + if constexpr (kEnableSwizzleSFOutput) { + 
sfa_layout = tile_to_shape(SwizzledSFALayoutAtom{}, make_shape(hidden_size, sequence_length), Step<_1,_2>{}); + sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(hidden_size, sequence_length), Step<_2,_1>{}); + } else { + sfa_layout = make_layout(make_shape(make_shape(Int{}, hidden_size/SFVecSize), sequence_length), make_stride(make_stride(_0{}, _1{}), hidden_size/SFVecSize)); + sfd_layout = make_layout(make_shape(hidden_size, make_shape(Int{}, sequence_length/SFVecSize)), make_stride(sequence_length/SFVecSize, make_stride(_0{}, _1{}))); + } + // Define shapes (dynamic) + auto M = hidden_size; + auto N = sequence_length; + Tensor tensorA = make_tensor(A, make_shape(hidden_size, sequence_length), LayoutLeft{}); + Tensor tensorB = make_tensor(B, make_shape(RhtTensorSize, RhtTensorSize), LayoutLeft{}); + Tensor tensorD = make_tensor(D, make_shape(hidden_size, sequence_length), LayoutRight{}); + Tensor tensorQA = make_tensor(QA, make_shape(hidden_size, sequence_length), LayoutLeft{}); + Tensor tensorSFD = make_tensor(SFD, sfd_layout); + Tensor tensorSFA = make_tensor(SFA, sfa_layout); + // Define strides (from tensors) + auto dA = stride(tensorA); // (dM,dK) + auto dB = stride(tensorB); // (dN,dK) + auto dD = stride(tensorD); // (dM,dN) + auto dQA = stride(tensorQA); // (dM,dK) + using ClusterShape = Shape< _1, _1, _1>; + auto cluster_shape = ClusterShape{}; + auto cluster_tile_shape = Shape<_128,Int,Int>{}; + auto cluster_tile_mainloop = Shape<_128,Int,_128>{}; + + // Each mainloop / epilogue loads 128 x 64 tiles while each MMA proceeds with 128 x 16 tiles + static int constexpr EpilogueUnrollFactor = + size<2>(cluster_tile_mainloop) / size<2>(cluster_tile_shape); + // Construct the MMA + auto mma = make_tiled_mma(SM100_MMA_F16BF16_SS(cluster_tile_shape), size<1>(cluster_tile_shape), + UMMA::Major::MN, UMMA::Major::MN>{}, + Layout>{}); + + // Assert that the TiledMMA uses all CTAs in the CGA. 
+ CUTE_STATIC_ASSERT_V(size(cluster_shape) == size(mma)); + CUTE_STATIC_ASSERT_V(evenly_divides(cluster_tile_shape, tile_shape(mma))); + + // Determine the A and B shapes + auto mma_shape_B = partition_shape_B(mma, make_shape(size<1>(cluster_tile_shape), size<2>(cluster_tile_shape))); + + using TiledMma = decltype(mma); + using AtomThrID = typename TiledMma::AtomThrID; + + using SmemShape_M = decltype(shape_div(shape<0>(cluster_tile_shape), shape_div(shape<0>(cluster_tile_shape), size<0>(cluster_tile_shape) / size(AtomThrID{})))); + using SmemShape_N = decltype(shape_div(shape<1>(cluster_tile_shape), shape_div(shape<1>(cluster_tile_shape), size<1>(cluster_tile_shape) / size(AtomThrID{})))); + using SmemShape_K = decltype(cute::get<2>(cluster_tile_shape)); + + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, TB, SmemShape_N, SmemShape_K>()); + + auto mma_shape_A = partition_shape_A(mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop))); + using SmemShape_M_A = decltype(shape_div(shape<0>(cluster_tile_mainloop), shape_div(shape<0>(cluster_tile_mainloop), size<0>(cluster_tile_mainloop) / size(AtomThrID{})))); + using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop)); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>()); + + static uint32_t constexpr TotalTmemRows = 128; + static uint32_t constexpr Sm100TmemCapacityColumns = 512; + static uint32_t constexpr TotalTmem = TotalTmemRows * Sm100TmemCapacityColumns; + static uint32_t constexpr AccumulatorPipelineStageCount = + TotalTmem / + (cute::size<0>(cluster_tile_shape) * cute::size<1>(cluster_tile_shape)); + + // Define the smem layouts (static) + // Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory + constexpr int SchedulerPipelineStageCount = 6; + static int constexpr MainloopPipelineBytes = 
sizeof(typename cutlass::detail::CustomizedPipelineTmaUmmaAsync< + 1, + Shape<_1,_1,_1>, + Shape<_1, _1, _1>>::SharedStorage); + + static int constexpr ClcResponseBytes = sizeof(CLCResponse) * SchedulerPipelineStageCount; + static int constexpr CLCThrottlePipelineBytes = sizeof(typename cutlass::PipelineAsync::SharedStorage); + static int constexpr CLCPipelineBytes = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + static int constexpr TmemDeallocBytes = sizeof(cutlass::arch::ClusterBarrier); + static int constexpr BTensorBytes = cute::size(mma_shape_B) * sizeof(TB); + static int constexpr AccPipelineBytes = sizeof(typename cutlass::PipelineUmmaAsync>::SharedStorage); + static int constexpr TmemBasePtrsBytes = sizeof(uint32_t); + static int constexpr kBlackwellSmemSize = 232448; // 232KB in bytes + static int constexpr kBytesPerStage = + cute::size(mma_shape_A) * sizeof(TA) + MainloopPipelineBytes; + static int constexpr kReservedBytes = ClcResponseBytes + CLCThrottlePipelineBytes + TmemBasePtrsBytes + + CLCPipelineBytes + TmemDeallocBytes+BTensorBytes + AccPipelineBytes; // Reserve for barriers and other uses + static int constexpr kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage; + auto sP = Int{}; // SMEM pipelines + auto sA = UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(mma_shape_A, sP), Step<_2,_1,_3>{}); // (MMA,MMA_M,MMA_K,PIPE) + auto sB = UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(mma_shape_B, _1{})); // (MMA,MMA_N,MMA_K, _1) + auto sD = Layout<_1>{}; // XXX Dummy + + auto tma_load_a = make_tma_copy_A_sm100( + SM90_TMA_LOAD{}, + tensorA, + sA(_,_,_,0), + cluster_tile_mainloop, + mma); + auto tma_load_b = make_tma_copy_B_sm100( + SM90_TMA_LOAD{}, + tensorB, + sB(_,_,_,0), + cluster_tile_shape, + mma); + + // Assert checks problem size should be multiple of 64 + assert(M % 64 == 0); + assert(N % 64 == 0); + + uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile_shape)))); + uint32_t 
tiles_in_n = uint32_t(size(ceil_div(N, k_tile_size))); + uint32_t tiles = tiles_in_m * tiles_in_n; + + dim3 dimBlock(512); + dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); + dim3 dimGrid(tiles_in_m, tiles_in_n, 1); + + int smem_size = sizeof( + SharedStorage< + TA, + TB, + decltype(sA), + decltype(sB), + ClusterShape, + AccumulatorPipelineStageCount, + EpilogueUnrollFactor, + SchedulerPipelineStageCount>); + + auto* kernel_ptr = &row_col_rht_gemm_device< + decltype(M), decltype(N), decltype(k_tile_size), + decltype(cluster_shape), decltype(cluster_tile_shape), + TA, decltype(dA), decltype(sA), decltype(tma_load_a), + TB, decltype(dB), decltype(sB), decltype(tma_load_b), + TD, decltype(dD), decltype(sD), + TSFD, decltype(sfd_layout), + TQA, decltype(dQA), + TSFA, decltype(sfa_layout), + decltype(mma), + AccumulatorPipelineStageCount, + SchedulerPipelineStageCount, + kEnableStochasticRounding, + kEnableRHTColQuant, + kEnableRowQuant, + kEnableFastMath>; + + NVTE_CHECK_CUDA(cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream}; + cutlass::Status status = cutlass::launch_kernel_on_cluster( + params, (void const *)kernel_ptr, M, N, k_tile_size, cluster_shape, cluster_tile_shape, + tensorA.data(), dA, sA, tma_load_a, + tensorB.data(), dB, sB, tma_load_b, + tensorD.data(), dD, sD, + tensorSFD.data(), sfd_layout, + tensorQA.data(), dQA, + tensorSFA.data(), sfa_layout, + mma, a_global_amax, d_global_amax, rng_state); + + NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK(status == cutlass::Status::kSuccess, "Kernel launch failed."); + +} + +} // namespace +} // namespace detail + +// clang-format on + +void hadamard_transform_cast_fusion(const Tensor &input_, Tensor &output_, + const Tensor &hadamard_matrix_, + QuantizationConfig quant_config, + cudaStream_t stream) { + 
NVTE_API_CALL(hadamard_transform_cast_fusion); + + // Check input and output tensors + NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be BF16 tensor, but scaling mode is ", + to_string(input_.scaling_mode), "."); + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + const SimpleTensor &input = input_.data; + + // rowwise cast and columnwise cast has different output data pointers + bool has_rowwise_quant = false; + bool has_columnwise_quant = false; + void *rowwise_data_ptr = nullptr; + void *rowwise_scale_inv_ptr = nullptr; + void *rowwise_amax_ptr = nullptr; + void *columnwise_data_ptr = nullptr; + void *columnwise_scale_inv_ptr = nullptr; + void *columnwise_amax_ptr = nullptr; + + // examine the output tensor (single tensor for dense) + if (output_.data.dptr != nullptr) { + has_rowwise_quant = true; + rowwise_data_ptr = output_.data.dptr; + rowwise_scale_inv_ptr = output_.scale_inv.dptr; + rowwise_amax_ptr = output_.amax.dptr; + } else { + has_columnwise_quant = true; + columnwise_data_ptr = output_.columnwise_data.dptr; + columnwise_scale_inv_ptr = output_.columnwise_scale_inv.dptr; + columnwise_amax_ptr = output_.columnwise_amax.dptr; + } + + NVTE_CHECK(has_rowwise_quant || has_columnwise_quant, "Output tensor must have rowwise or columnwise quant."); + + // Stochastic rounding config + const bool use_stochastic_rounding = quant_config.stochastic_rounding; + const size_t *rng_state = nullptr; + if (quant_config.rng_state != nullptr) { + Tensor &rng_state_tensor = *convertNVTETensor(quant_config.rng_state); + NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape); + rng_state = 
reinterpret_cast(rng_state_tensor.data.dptr); + } + + // Template arguments + using TA = cute::bfloat16_t; + using TB = cute::bfloat16_t; + using TD = cutlass::float_e2m1_t; + using TSFD = cutlass::float_ue4m3_t; + using TQA = TD; + using TSFA = TSFD; + + checkCuDriverContext(stream); + + // Check Hadamard matrix + constexpr int kHadamardDimension = 16; + NVTE_CHECK(hadamard_matrix_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Hadamard matrix must be BF16 tensor, but scaling mode is ", + to_string(hadamard_matrix_.scaling_mode), "."); + NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16, + "Hadamard matrix must be BF16 tensor, but dtype is ", + to_string(hadamard_matrix_.dtype()), "."); + const SimpleTensor &hadamard_matrix = hadamard_matrix_.data; + NVTE_CHECK( + (hadamard_matrix_.shape() == std::vector{kHadamardDimension, kHadamardDimension}), + "Hadamard matrix must have shape=", + std::vector{kHadamardDimension, kHadamardDimension}, + ", but got shape=", hadamard_matrix_.shape(), "."); + const size_t hadamard_dimension = hadamard_matrix.shape[0]; + + const size_t ndim = input.shape.size(); + const size_t n = input.shape[ndim - 1]; + size_t m = 1; + for (size_t i = 0; i < ndim - 1; ++i) { + m *= input.shape[i]; + } + + auto sm_count = transformer_engine::cuda::sm_count(); + + NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension."); + + NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by hadamard_dimension"); + + int k_tile_size = 1024; + + // TODO: add support for swizzle sf output + const bool use_swizzle_sf_output = false; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kEnableStochasticRounding, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + has_columnwise_quant, kEnableRhtColQuant, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + has_rowwise_quant, kEnableRowQuant, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_swizzle_sf_output, kEnableSwizzleSFOutput, + 
TRANSFORMER_ENGINE_SWITCH_CONDITION( + quant_config.use_fast_math, kUseFastMath, + + if constexpr (kEnableRhtColQuant || kEnableRowQuant) { + detail::row_col_rht_gemm_ntt_w_sfc< + kEnableStochasticRounding, kEnableRhtColQuant, kEnableRowQuant, + kEnableSwizzleSFOutput, TA, TB, TD, TSFD, TQA, TSFA, kUseFastMath>( + /*sequence_length=*/m, /*hidden_size=*/n, + /*A=*/reinterpret_cast(input.dptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*D=*/reinterpret_cast(columnwise_data_ptr), + /*SFD=*/reinterpret_cast(columnwise_scale_inv_ptr), + /*QA=*/reinterpret_cast(rowwise_data_ptr), + /*SFA=*/reinterpret_cast(rowwise_scale_inv_ptr), + /*a_global_amax=*/reinterpret_cast(rowwise_amax_ptr), + /*d_global_amax=*/reinterpret_cast(columnwise_amax_ptr), + /*rng_state=*/rng_state, /*sm_count=*/sm_count, + /*stream=*/stream, /*k_tile_size=*/k_tile_size); + } else { + NVTE_ERROR("Invalid kernel configuration (kEnableRHTColQuant=", + kEnableRhtColQuant, ", kEnableRowQuant=", kEnableRowQuant, ")."); + } + + ););););); +} + +} // namespace transformer_engine + +void nvte_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor output, + const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { + NVTE_API_CALL(nvte_hadamard_transform_cast_fusion); + using namespace transformer_engine; + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + hadamard_transform_cast_fusion( + *convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp, stream); +} diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h index bee939f0cd..5dd44e1cbe 100644 --- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h +++ 
b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -49,6 +49,7 @@ void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output, int /*! \brief Perform the columnwise hadamard transform cast fusion. * * This function is experimental and the API is not stable. + * This function will later be deprecated and replaced by nvte_hadamard_transform_cast_fusion * * \param[in] input Input tensor to apply Hadamard transform. * \param[in,out] output Output tensor. @@ -61,6 +62,21 @@ void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTE const NVTEQuantizationConfig quant_config, cudaStream_t stream); +/*! \brief Perform the regular rowwise cast and columnwise hadamard transform cast fusion. + * + * This function is experimental and the API is not stable. + * + * \param[in] input Input tensor to apply Hadamard transform. + * \param[in,out] output Output tensor. + * \param[in] hadamard_matrix Hadamard matrix. + * \param[in] quant_config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. + */ + void nvte_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor output, + const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream); + /*! \brief Split a tensor along dimension 0 and compute RHT amaxes for each split. * * This function is experimental and the API is not stable. 
From efad808e6b790cc03b265d2f006e5eb94d445f3e Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Fri, 2 Jan 2026 16:39:50 -0800 Subject: [PATCH 02/22] pass numerical unit test Signed-off-by: Zhongbo Zhu --- .../nvfp4/test_nvfp4_rht_quantize_exact.py | 64 +++-- ...cast_col_hadamard_transform_cast_fusion.cu | 64 ++++- transformer_engine/pytorch/csrc/common.h | 1 + .../pytorch/csrc/extensions/cast.cpp | 5 + transformer_engine/pytorch/csrc/quantizer.cpp | 220 ++++++++++-------- 5 files changed, 227 insertions(+), 127 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index 98be9a4f54..5826c4b95f 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -35,6 +35,7 @@ def check_quantization_nvfp4_versus_reference( M: int, N: int, contiguous: bool, + return_identity: bool, return_transpose: bool, use_cpp_allocator: bool, swizzled_scale: bool = False, @@ -61,7 +62,7 @@ def check_quantization_nvfp4_versus_reference( # Quantize nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, - rowwise=True, + rowwise=return_identity, columnwise=return_transpose, with_amax_reduction=False, amax_reduction_group=None, @@ -78,9 +79,11 @@ def check_quantization_nvfp4_versus_reference( x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) # Extract data from NVFP4Tensor - assert x_nvfp4_sut._rowwise_data is not None - qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) - assert x_nvfp4_sut._rowwise_scale_inv is not None + qx: torch.Tensor = ( + x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._rowwise_data is not None + else None + ) sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv qx_t = ( x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) @@ -91,13 +94,13 @@ def check_quantization_nvfp4_versus_reference( amax_rowwise = x_nvfp4_sut._amax_rowwise amax_colwise = x_nvfp4_sut._amax_columnwise - qx = 
unpack_fp4(qx) + qx = unpack_fp4(qx) if qx is not None else None qx_t = unpack_fp4(qx_t) if qx_t is not None else None # Reference quantization using NVFP4QuantizerRef with built-in RHT ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, - rowwise=True, + rowwise=return_identity, columnwise=return_transpose, pow_2_scales=False, eps=0.0, @@ -130,13 +133,14 @@ def check_quantization_nvfp4_versus_reference( sx_t_ref = None ref_amax_colwise_t = None - torch.testing.assert_close(amax_rowwise, ref_amax_rowwise, atol=0.0, rtol=0.0) + if return_identity: + torch.testing.assert_close(amax_rowwise, ref_amax_rowwise, atol=0.0, rtol=0.0) - torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) - # Compare only the valid portion of scale tensors (reference may not have padding) - ref_sx_shape = sx_ref.shape - sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] - torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + # Compare only the valid portion of scale tensors (reference may not have padding) + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) if return_transpose: torch.testing.assert_close(amax_colwise, ref_amax_colwise_t, atol=0.0, rtol=0.0) @@ -185,7 +189,7 @@ def check_quantization_nvfp4_versus_reference( ) @pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize( - "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] + "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] ) @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] @@ -197,15 +201,29 @@ def test_rht_with_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, N: int, - return_transpose: bool, + quantize_mode: str, use_cpp_allocator: bool, with_random_sign_mask: bool, ) -> 
None: + + if quantize_mode == "quantize": + return_identity = True + return_transpose = False + elif quantize_mode == "quantize_transpose": + return_identity = True + return_transpose = True + elif quantize_mode == "quantize_colwise_only": + return_identity = False + return_transpose = True + else: + raise ValueError(f"Invalid quantize mode: {quantize_mode}") + check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, M=M, N=N, contiguous=True, + return_identity=return_identity, return_transpose=return_transpose, use_cpp_allocator=use_cpp_allocator, with_random_sign_mask=with_random_sign_mask, @@ -221,7 +239,7 @@ def test_rht_with_quantization_block_tiling_versus_reference( ) @pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize( - "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] + "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] ) @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] @@ -233,15 +251,29 @@ def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, M: int, N: int, - return_transpose: bool, + quantize_mode: str, use_cpp_allocator: bool, with_random_sign_mask: bool, ): + + if quantize_mode == "quantize": + return_identity = True + return_transpose = False + elif quantize_mode == "quantize_transpose": + return_identity = True + return_transpose = True + elif quantize_mode == "quantize_colwise_only": + return_identity = False + return_transpose = True + else: + raise ValueError(f"Invalid quantize mode: {quantize_mode}") + check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, M=M, N=N, contiguous=False, + return_identity=return_identity, return_transpose=return_transpose, use_cpp_allocator=use_cpp_allocator, with_random_sign_mask=with_random_sign_mask, diff --git a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu 
b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu index 14aa4e0cdb..1f6898e3d3 100644 --- a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu @@ -149,7 +149,7 @@ template + bool kUseFastMath_ = true> __launch_bounds__(512, 1) __global__ static void row_col_rht_gemm_device( MShape M, @@ -189,7 +189,7 @@ __global__ static void row_col_rht_gemm_device( static constexpr bool kEnableStochasticRounding = kEnableStochasticRounding_; static constexpr bool kEnableRHTColQuant = kEnableRHTColQuant_; static constexpr bool kEnableRowQuant = kEnableRowQuant_; - static constexpr bool kEnableFastMath = kEnableFastMath_; + static constexpr bool kUseFastMath = kUseFastMath_; static int constexpr RhtTensorSize = 16; static int constexpr kTmaRhtTensorTransactionBytes = cutlass::bits_to_bytes( RhtTensorSize * RhtTensorSize * cute::sizeof_bits_v); @@ -700,7 +700,11 @@ __global__ static void row_col_rht_gemm_device( : 1.0f; float const global_decode_scale = 1.0f / global_encode_scale; - float const global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + // Scaling factor for fast math path + float global_encode_scale_multiplier = 1.0f; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } auto sfc_converter = cutlass::NumericConverter{}; do { @@ -752,6 +756,19 @@ __global__ static void row_col_rht_gemm_device( accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); ++accumulator_pipe_consumer_state; + if constexpr (!kUseFastMath) { + // Downcast to BF16 for bit-wise compatibility with + // unfused kernels + auto convert_accum_to_bf16 = + cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = + cutlass::NumericArrayConverter{}; + tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); 
+ tTR_rAcc_frag(_1{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_1{}))); + } + auto compute_frgs = reinterpret_cast *>(tTR_rAcc_frag.data()); auto output_frgs = reinterpret_cast *>(tDrD_frag.data()); CUTLASS_PRAGMA_UNROLL @@ -759,7 +776,17 @@ __global__ static void row_col_rht_gemm_device( vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); } - pvscales = cutlass::multiplies>{}(vec_maxs, global_encode_scale_multiplier); + if constexpr (kUseFastMath) { + // Fast math: multiply with precomputed reciprocal + pvscales = cutlass::multiplies>{}( + vec_maxs, global_encode_scale_multiplier); + } else { + // Accurate math: perform division + pvscales = + cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}( + pvscales, global_encode_scale); + } auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); tD_rRowSFD_frg(_0{}) = pvscales_cvted; @@ -769,7 +796,7 @@ __global__ static void row_col_rht_gemm_device( global_decode_scale); cutlass::Array acc_scales; - if constexpr (kEnableFastMath) { + if constexpr (kUseFastMath) { // fast math: use reciprocal approximate to replace div acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); } else { @@ -882,7 +909,11 @@ __global__ static void row_col_rht_gemm_device( : 1.0f; float const global_decode_scale = 1.0f / global_encode_scale; - float const global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + // Scaling factor for fast math path + float global_encode_scale_multiplier = 1.0f; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } auto sfa_converter = cutlass::NumericConverter{}; do { uint32_t skip_wait = K_TILE_MAX <= 0; @@ -916,12 +947,21 @@ __global__ static void row_col_rht_gemm_device( for (int v = 0; v < size(tQArA)/VectorSize; v++) { auto compute_frgs_up = cutlass::NumericArrayConverter{}(compute_frgs[v]); auto amax = amax_reduction(ElementAccumulator(0), compute_frgs_up); - 
auto pvscales= cutlass::multiplies{}(amax, global_encode_scale_multiplier); + // declare pvscales + ElementAccumulator pvscales; + if constexpr (kUseFastMath) { + // Fast math: multiply with precomputed reciprocal + pvscales = cutlass::multiplies{}(amax, global_encode_scale_multiplier); + } else { + // Accurate math: perform division + pvscales = cutlass::divides{}(amax, fp4_max); + pvscales = cutlass::multiplies{}(pvscales, global_encode_scale); + } filter(tQArSFA)(v) = sfa_converter(pvscales); auto qpvscale_ups = cutlass::NumericConverter{}(filter(tQArSFA)(v)); auto qpvscale_scaled = cutlass::multiplies{}(qpvscale_ups, global_decode_scale); ElementAccumulator acc_scales; - if constexpr (kEnableFastMath) { + if constexpr (kUseFastMath) { // fast math: use reciprocal approximate to replace div acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); } else { @@ -980,7 +1020,7 @@ __global__ static void row_col_rht_gemm_device( // QA: m x n: col-major // SFA: m/16 x n: col-major template +class TA, class TB, class TD, class TSFD, class TQA, class TSFA, bool kUseFastMath=true> void row_col_rht_gemm_ntt_w_sfc( int sequence_length, int hidden_size, @@ -1160,7 +1200,7 @@ void row_col_rht_gemm_ntt_w_sfc( kEnableStochasticRounding, kEnableRHTColQuant, kEnableRowQuant, - kEnableFastMath>; + kUseFastMath>; NVTE_CHECK_CUDA(cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -1216,7 +1256,9 @@ void hadamard_transform_cast_fusion(const Tensor &input_, Tensor &output_, rowwise_data_ptr = output_.data.dptr; rowwise_scale_inv_ptr = output_.scale_inv.dptr; rowwise_amax_ptr = output_.amax.dptr; - } else { + } + + if (output_.columnwise_data.dptr != nullptr) { has_columnwise_quant = true; columnwise_data_ptr = output_.columnwise_data.dptr; columnwise_scale_inv_ptr = output_.columnwise_scale_inv.dptr; diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 6aab9938b3..5f1609f4cd 
100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -370,6 +370,7 @@ class NVFP4Quantizer : public Quantizer { private: void quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, bool compute_amax); + void quantize_with_rht_unfused_helper(const TensorWrapper& input, TensorWrapper& out, TensorWrapper& rht_output_t_cpp, QuantizationConfigWrapper& quant_config, QuantizationConfigWrapper& quant_config_columnwise, cudaStream_t stream); }; std::unique_ptr convert_quantizer(py::handle quantizer); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index f8f793f036..4459fdcc65 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -997,6 +997,11 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, // Enable NVFP4 kernels to use math operations that sacrifice // accuracy for performance. These optimizations are experimental // and inconsistently implemented. + // What math is accelerated? Only the high precision math, so numerical impact is minimal + // 1. replace x / y by x * (1/y) + // 2. replace 1 / x by reciporal_approximate_ftz(x) + // 3. 
when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, + // this will essentially remove a round trip between FP32 to BF16 then FP32 const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); if (use_fast_math) { for (auto &config : quant_config_list) { diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 0da5f69197..8be306adca 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -9,6 +9,7 @@ #include "common.h" #include "pybind.h" #include "torch/torch.h" +#include "common/util/system.h" namespace transformer_engine::pytorch { @@ -2030,6 +2031,79 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( return {std::move(out_cpp), std::move(tensor)}; } +void NVFP4Quantizer::quantize_with_rht_unfused_helper(const TensorWrapper& input, TensorWrapper& out, TensorWrapper& rht_output_t_cpp, QuantizationConfigWrapper& quant_config, QuantizationConfigWrapper& quant_config_columnwise, cudaStream_t stream){ + // only triggered for irregular shapes where RHT cast fusion kernel is not eligible + if (rowwise_usage) { + // For rowwise usage, we need to quantize the input directly, but we need to avoid quantizing columnwise + TensorWrapper out_identity(out.scaling_mode()); + auto out_identity_data = out.get_rowwise_data(); + auto out_identity_scale_inv = out.get_rowwise_scale_inv(); + auto out_identity_amax = out.get_amax(); + out_identity.set_rowwise_data(out_identity_data.data_ptr, + static_cast(out_identity_data.dtype), + out_identity_data.shape); + out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr, + static_cast(out_identity_scale_inv.dtype), + out_identity_scale_inv.shape); + out_identity.set_amax(out_identity_amax.data_ptr, static_cast(out_identity_amax.dtype), + out_identity_amax.shape); + + NVTE_SCOPED_GIL_RELEASE( + { nvte_quantize_v2(input.data(), out_identity.data(), quant_config, 
stream); }); + } + + if (columnwise_usage) { + // Get the output columnwise data, scale_inv, and amax + auto out_columnwise_data = out.get_columnwise_data(); + auto out_columnwise_scale_inv = out.get_columnwise_scale_inv(); + // NOTE: should already be populated. + auto out_columnwise_amax = out.get_columnwise_amax(); + + // Create a wrapper for the columnwise output, as the rowwise output. + // The reason is due to the input `rht_output_t` is already in the transposed layout. + // Thus, we only need a rowwise quantization to generate the columnwise output. + TensorWrapper out_transpose(out.scaling_mode()); + // Note: since we are faking columnwise tensor into rowwise, the flat first dim check will fail + // need to convert the shape to 2D here + auto colwise_data_shape = out_columnwise_data.shape; + std::vector colwise_data_shape_2d; + // shape could be [512, 32, 64], that's actually 512, 32, 128 because 2 FP4 take 1 byte + // the 2D shape should be [512, 32*128], but columnwise data shape expect last dim to be halved again + // so the multiple 2 get cancelled out + colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); + size_t last_dim = 1; + for (size_t i = 1; i < colwise_data_shape.ndim; ++i) { + last_dim *= colwise_data_shape.data[i]; + } + colwise_data_shape_2d.push_back(last_dim); + + out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, + static_cast(out_columnwise_data.dtype), + colwise_data_shape_2d); + out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr, + static_cast(out_columnwise_scale_inv.dtype), + out_columnwise_scale_inv.shape); + out_transpose.set_amax(out_columnwise_amax.data_ptr, + static_cast(out_columnwise_amax.dtype), + out_columnwise_amax.shape); + + // Invoking fallback RHT kernel unfused. + + NVTE_SCOPED_GIL_RELEASE({ + // Perform the RHT(input.t), and write to rht_output_cpp.columnwise. 
+ nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0, + this->rht_matrix_random_sign_mask_t, stream); + }); + + // Quantize kernel will treat everything as rowwise input/output, which is + // intended. + NVTE_SCOPED_GIL_RELEASE({ + nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), quant_config_columnwise, + stream); + }); + } +} + void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, bool compute_amax) { @@ -2055,14 +2129,25 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } size_t cols = input.size(input.ndim() - 1); + // Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT + bool eligible_for_rht_cast_fusion = + input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; + // Stochastic rounding // When both rowwise and columnwise quantization are used with RHT, // we need separate RNG states for each to ensure they use different random numbers. TensorWrapper te_rng_state; TensorWrapper te_rng_state_columnwise; QuantizationConfigWrapper quant_config_columnwise; + + // Only need a separate rng state when: + // 1. Stochastic rounding is enabled + // 2. RHT is enabled + // 3. Columnwise usage is enabled + // 4. 
Rowwise and columnwise quantization are not fused, + // because within a single kernel we can generate two different random numbers for rowwise and columnwise const bool need_separate_columnwise_rng = - this->stochastic_rounding && this->with_rht && this->columnwise_usage; + this->stochastic_rounding && this->with_rht && this->columnwise_usage && (!eligible_for_rht_cast_fusion); if (this->stochastic_rounding) { const size_t rng_elts_per_thread = 1024; // Wild guess, probably can be tightened @@ -2088,10 +2173,6 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } } - // Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT - bool eligible_for_rht_cast_fusion = - input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; - // Compute amax. if (this->with_rht) { if (input.dtype() != DType::kBFloat16) { @@ -2157,103 +2238,42 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou { this->amax_reduction_group->allreduce_coalesced(amax_tensors, opts)->wait(); }); } - if (this->with_rht) { - if (rowwise_usage) { - // For rowwise usage, we need to quantize the input directly, but we need to avoid quantizing columnwise - TensorWrapper out_identity(out.scaling_mode()); - auto out_identity_data = out.get_rowwise_data(); - auto out_identity_scale_inv = out.get_rowwise_scale_inv(); - auto out_identity_amax = out.get_amax(); - out_identity.set_rowwise_data(out_identity_data.data_ptr, - static_cast(out_identity_data.dtype), - out_identity_data.shape); - out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr, - static_cast(out_identity_scale_inv.dtype), - out_identity_scale_inv.shape); - out_identity.set_amax(out_identity_amax.data_ptr, static_cast(out_identity_amax.dtype), - out_identity_amax.shape); - - NVTE_SCOPED_GIL_RELEASE( - { nvte_quantize_v2(input.data(), out_identity.data(), quant_config, stream); }); - } - - if (columnwise_usage) { - // Get the 
output columnwise data, scale_inv, and amax - auto out_columnwise_data = out.get_columnwise_data(); - auto out_columnwise_scale_inv = out.get_columnwise_scale_inv(); - // NOTE: should already be populated. - auto out_columnwise_amax = out.get_columnwise_amax(); - - // Create a wrapper for the columnwise output, as the rowwise output. - // The reason is due to the input `rht_output_t` is already in the transposed layout. - // Thus, we only need a rowwise quantization to generate the columnwise output. - TensorWrapper out_transpose(out.scaling_mode()); - // Note: since we are faking columnwise tensor into rowwise, the flat first dim check will fail - // need to convert the shape to 2D here - auto colwise_data_shape = out_columnwise_data.shape; - std::vector colwise_data_shape_2d; - // shape could be [512, 32, 64], that's actually 512, 32, 128 because 2 FP4 take 1 byte - // the 2D shape should be [512, 32*128], but columnwise data shape expect last dim to be halved again - // so the multiple 2 get cancelled out - colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); - size_t last_dim = 1; - for (size_t i = 1; i < colwise_data_shape.ndim; ++i) { - last_dim *= colwise_data_shape.data[i]; - } - colwise_data_shape_2d.push_back(last_dim); - - out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, - static_cast(out_columnwise_data.dtype), - colwise_data_shape_2d); - out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr, - static_cast(out_columnwise_scale_inv.dtype), - out_columnwise_scale_inv.shape); - out_transpose.set_amax(out_columnwise_amax.data_ptr, - static_cast(out_columnwise_amax.dtype), - out_columnwise_amax.shape); + // Fast math toggle: RHT transform can be accelerated + // What math is accelerated? Only the high precision math, so numerical impact is minimal + // 1. replace x / y by x * (1/y) + // 2. replace 1 / x by reciporal_approximate_ftz(x) + // 3. 
when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, + // this will essentially remove a round trip between FP32 to BF16 then FP32 + const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); + if (use_fast_math) { + quant_config.set_use_fast_math(true); + quant_config_columnwise.set_use_fast_math(true); + } + if (this->with_rht) { + if (eligible_for_rht_cast_fusion) { + // fusion kernel requires passing in RHT matrix directly for maximum performance + auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix); + // Fusion kernel that does the following: + // 1. Rowwise quantization + // 2. RHT followed by columnwise quantization & transpose + NVTE_SCOPED_GIL_RELEASE({ nvte_hadamard_transform_cast_fusion(input.data(), out.data(), rht_matrix_nvte.data(), quant_config, stream); }); + } else { // Use separate RNG state for columnwise to ensure different random numbers than rowwise - auto& columnwise_quant_config = - need_separate_columnwise_rng ? quant_config_columnwise : quant_config; - - if (!eligible_for_rht_cast_fusion) { - // Invoking fallback RHT kernel. - - // If using RHT, then amax will be computed in the RHT step - // If not using RHT, then amax will be computed based on input x - at::Tensor rht_output_t; // The RHT(x_t) output, in columnwise layout - // This wrapper is going to be passed as input to the quantization kernel. - TensorWrapper rht_output_t_cpp; // Wrapper to contain the RHT(x) and RHT(x_t) outputs - rht_output_t = - allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); - // NOTE (frsun): This is non-intuitive, we are writing the - // result of transposed RHT to the output of rowwise. - rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), - std::vector{cols, rows}); - - NVTE_SCOPED_GIL_RELEASE({ - // Perform the RHT(input.t), and write to rht_output_cpp.columnwise. 
- nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0, - this->rht_matrix_random_sign_mask_t, stream); - }); - - // Quantize kernel will treat everything as rowwise input/output, which is - // intended. - NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), columnwise_quant_config, - stream); - }); - } else { - // RHT cast fusion kernel. - NVTE_CHECK(this->rht_matrix.defined() && this->rht_matrix.numel() > 0, - "RHT matrix is not set"); - auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix); - NVTE_SCOPED_GIL_RELEASE({ - nvte_hadamard_transform_cast_fusion_columnwise(input.data(), out_transpose.data(), - rht_matrix_nvte.data(), - columnwise_quant_config, stream); - }); - } + // This is only necessary because it's the unfused path where rowwise and columnwise + // are separate kernel launches + auto& columnwise_quant_config_to_use = + need_separate_columnwise_rng ? quant_config_columnwise : quant_config; + // unfused path also needs memory allocation for intermediate buffer for RHT output + at::Tensor rht_output_t; // The RHT(x_t) output, in columnwise layout + // This wrapper is going to be passed as input to the quantization kernel. + TensorWrapper rht_output_t_cpp; // Wrapper to contain the RHT(x) and RHT(x_t) outputs + rht_output_t = allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); + // NOTE (frsun): This is non-intuitive, we are writing the + // result of transposed RHT to the output of rowwise. 
+ rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), + std::vector{cols, rows}); + this->quantize_with_rht_unfused_helper(input, out, rht_output_t_cpp, quant_config, columnwise_quant_config_to_use, stream); } } else { NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); }); From c32537956ce70e50f65abfee80f218480e7daf6f Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Fri, 2 Jan 2026 17:15:04 -0800 Subject: [PATCH 03/22] format Signed-off-by: Zhongbo Zhu --- ...cast_col_hadamard_transform_cast_fusion.cu | 206 +++++++++--------- .../transformer_engine/hadamard_transform.h | 8 +- transformer_engine/pytorch/csrc/common.h | 6 +- transformer_engine/pytorch/csrc/quantizer.cpp | 37 ++-- 4 files changed, 135 insertions(+), 122 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu index 1f6898e3d3..99fa711b3a 100644 --- a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu @@ -4,41 +4,41 @@ * See LICENSE for license information. 
************************************************************************/ - #include - #include - #include - #include - #include - #include - #include - - #include - #include - #include - #include - - #include "common/common.h" - #include "common/util/cuda_runtime.h" - #include "common/util/curanddx.hpp" - #include "common/util/ptx.cuh" - #include "common/utils.cuh" - #include "customized_pipeline.cuh" - #include "cutlass/arch/barrier.h" - #include "cutlass/arch/reg_reconfig.h" - #include "cutlass/cluster_launch.hpp" - #include "cutlass/cutlass.h" - #include "cutlass/detail/sm100_blockscaled_layout.hpp" - #include "cutlass/fast_math.h" - #include "cutlass/float8.h" - #include "cutlass/float_subbyte.h" - #include "cutlass/gemm/collective/builders/sm100_common.inl" - #include "cutlass/numeric_conversion.h" - #include "cutlass/numeric_types.h" - #include "cutlass/pipeline/pipeline.hpp" - #include "cutlass/platform/platform.h" - #include "cutlass/util/GPU_Clock.hpp" - #include "cutlass/util/command_line.h" - #include "cutlass/util/print_error.hpp" +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "common/util/cuda_runtime.h" +#include "common/util/curanddx.hpp" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" +#include "customized_pipeline.cuh" +#include "cutlass/arch/barrier.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/cluster_launch.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/float8.h" +#include "cutlass/float_subbyte.h" +#include "cutlass/gemm/collective/builders/sm100_common.inl" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/platform/platform.h" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/command_line.h" +#include 
"cutlass/util/print_error.hpp" // clang-format off @@ -79,7 +79,7 @@ cutlass::Array StochasticNumericConverterBase( } CUTLASS_DEVICE -cutlass::Array +cutlass::Array StochasticNumericConverter(cutlass::Array const &input, cutlass::Array const &rbits) { using result_type = cutlass::Array; result_type output; @@ -255,7 +255,7 @@ __global__ static void row_col_rht_gemm_device( CUTLASS_DEVICE uint32_t tile_m() const { return work_tile_info.m_idx; } - CUTLASS_DEVICE uint32_t tile_n_base() const { + CUTLASS_DEVICE uint32_t tile_n_base() const { return work_tile_info.n_idx * uint32_t(k_tile_max); } @@ -265,7 +265,7 @@ __global__ static void row_col_rht_gemm_device( return cute::elem_less(cute::make_coord(work_tile_info.m_idx, work_tile_info.n_idx), cute::make_coord(tiles_in_m, tiles_in_n)) && work_tile_info.is_valid_tile; } CUTLASS_DEVICE bool is_first_wave() const { return wave_cnt == 0; } - CUTLASS_DEVICE auto advance_to_next_work(CLCPipeline& clc_pipeline, CLCPipelineState clc_pipe_producer_state) { + CUTLASS_DEVICE auto advance_to_next_work(CLCPipeline& clc_pipeline, CLCPipelineState clc_pipe_producer_state) { uint32_t mbarrier_addr = clc_pipeline.producer_get_barrier(clc_pipe_producer_state); // Wait for clcID buffer to become empty with a flipped phase clc_pipeline.producer_acquire(clc_pipe_producer_state); @@ -304,7 +304,7 @@ __global__ static void row_col_rht_gemm_device( &clc_response_ptr[state.index()])); asm volatile( "{\n\t" - "clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.multicast::cluster::all.b128 [%0], [%1];\n\t" + "clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.multicast::cluster::all.b128 [%0], [%1];\n\t" "}\n" : : "r"(result_addr), "r"(mbarrier_addr)); @@ -482,30 +482,30 @@ __global__ static void row_col_rht_gemm_device( Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) Tensor tCsB = 
make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE) - + int block_rank_in_cluster = cute::block_rank_in_cluster(); ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx Tensor tCgA = thr_mma.partition_A(gA_mk); // (MMA,MMA_M,MMA_K,k) Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) - + Layout cta_layout_mnk = make_layout(cluster_shape); Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); - + auto [tAgA, tAsA] = tma_partition( tma_load_a, get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), group_modes<0,3>(tCsA), group_modes<0,3>(tCgA)); - + auto [tBgB, tBsB] = tma_partition( tma_load_b, get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), group_modes<0,3>(tCsB), group_modes<0,3>(tCgB)); - + uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); if constexpr (kEnableRHTColQuant) { @@ -549,14 +549,14 @@ __global__ static void row_col_rht_gemm_device( scheduler.update_work_tile_info(); } while (scheduler.is_valid()); mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); - } + } else if (is_mma_warp) { cutlass::arch::warpgroup_reg_dealloc<32>(); if constexpr (kEnableRHTColQuant) { Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE) - + int block_rank_in_cluster = cute::block_rank_in_cluster(); ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx // Allocate "fragments" -- these are actually umma smem descriptors @@ -674,7 +674,7 @@ __global__ static void row_col_rht_gemm_device( pD, epilogue_tiler, make_coord(_,_, _), - 
Step<_1,_1, X>{}); // (BLK_M,BLK_N) + Step<_1,_1, X>{}); // (BLK_M,BLK_N) Tensor mSFD = make_tensor(make_gmem_ptr(SFD), sfd_layout); Tensor gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) Tensor pSFD = make_identity_tensor(mSFD.shape()); @@ -717,7 +717,7 @@ __global__ static void row_col_rht_gemm_device( Tensor tDpSFD_mn = pSFD_mn(_,_,scheduler.tile_m(), scheduler.tile_n_base() + k_tile); accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); - + auto Acc = bulk_tmem_epilogue(_,_,_,accumulator_pipe_consumer_state.index()); Tensor tDtAcc = thr_t2r.partition_S(Acc); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) Tensor tDgD = thr_t2r.partition_D(tDgD_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) @@ -788,7 +788,7 @@ __global__ static void row_col_rht_gemm_device( pvscales, global_encode_scale); } auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); - + tD_rRowSFD_frg(_0{}) = pvscales_cvted; auto qpvscale_ups = cutlass::NumericArrayConverter{}(tD_rRowSFD_frg(_0{})); auto qpvscale_scaled = cutlass::multiplies>{}( @@ -823,18 +823,18 @@ __global__ static void row_col_rht_gemm_device( output_frgs[v] = cutlass::NumericArrayConverter{}( cutlass::multiplies>{}( compute_frgs[v], - acc_scale)); + acc_scale)); } } - copy_if(tiled_r2g, + copy_if(tiled_r2g, [&](auto coord){ Tensor pSrc_view = group_modes<1,rank(pSrc)>(pSrc); return elem_less(pSrc_view(_0{},coord), shape(mD)); }, src, dst); - // 32bit vectorization copy 4 e4m3 SFD for per 64/(16,4):(0, 1) element + // 32bit vectorization copy 4 e4m3 SFD for per 64/(16,4):(0, 1) element constexpr int vec_len = 32 / sizeof_bits_v; Tensor tDrSFD_v = recast>(tDrSFD); Tensor tDgSFD_v = recast>(tDgSFD); @@ -887,7 +887,7 @@ __global__ static void row_col_rht_gemm_device( auto thr_r2g_QA = tiled_r2g_QA.get_slice(local_thread_idx); Tensor tQAsA = thr_s2r.partition_S(sA); // (Copy, Copy_M, Copy_N, PIPE) - + Tensor tQArA = make_tensor_like(make_layout(tQAsA(_, _, _, 
_0{}).shape())); // (Copy, Copy_M, Copy_N) // Tensor tQArA_PI = thr_s2r.partition_S(sA_PI); Tensor tQAgQA = thr_r2g_QA.partition_D(gQA_mn); @@ -979,21 +979,21 @@ __global__ static void row_col_rht_gemm_device( output_frgs[v] = cutlass::NumericArrayConverter{}( cutlass::multiplies>{}( compute_frgs_up, - acc_scale)); + acc_scale)); } } - copy_if(tiled_r2g_QA, + copy_if(tiled_r2g_QA, [&](auto coord){ Tensor tQApQA_view = group_modes<1,rank(tQApQA_mn)>(tQApQA_mn); return elem_less(tQApQA_view(_0{}, coord), shape(mQA)); }, tQArQA, tQAgQA_mn); - // 32bit vectorization copy 4 e4m3 SFA for per 64/(16,4):(0, 1) element + // 32bit vectorization copy 4 e4m3 SFA for per 64/(16,4):(0, 1) element constexpr int vec_len = 32 / sizeof_bits_v; Tensor tQArSFA_v = recast>(filter(tQArSFA)); Tensor tQAgSFA_v = recast>(filter(tQAgSFA_mn)); - copy_if( + copy_if( [&](auto coord){ Tensor tQApSFA_view = filter(tQApSFA_mn); return elem_less(tQApSFA_view(_0{}, coord * vec_len), shape(mSFA)); @@ -1036,10 +1036,10 @@ void row_col_rht_gemm_ntt_w_sfc( uint32_t sm_count, cudaStream_t stream, int k_tile_size = 1024) { - using namespace cute; + using namespace cute; static int constexpr SFVecSize = 16; static int constexpr RhtTensorSize = 16; - + static_assert(RhtTensorSize == 16, "RhtTensorSize must be 16"); using LinearSFALayout = decltype(make_layout(make_shape(make_shape(Int{}, 0), 0), make_stride(make_stride(_0{}, _1{}), 0))); using LinearSFCLayout = decltype(make_layout(make_shape(0, make_shape(Int{}, 0)), make_stride(0, make_stride(_0{}, _1{})))); @@ -1169,7 +1169,7 @@ void row_col_rht_gemm_ntt_w_sfc( uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile_shape)))); uint32_t tiles_in_n = uint32_t(size(ceil_div(N, k_tile_size))); uint32_t tiles = tiles_in_m * tiles_in_n; - + dim3 dimBlock(512); dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); dim3 dimGrid(tiles_in_m, tiles_in_n, 1); @@ -1196,14 +1196,14 @@ void row_col_rht_gemm_ntt_w_sfc( 
TSFA, decltype(sfa_layout), decltype(mma), AccumulatorPipelineStageCount, - SchedulerPipelineStageCount, + SchedulerPipelineStageCount, kEnableStochasticRounding, kEnableRHTColQuant, kEnableRowQuant, kUseFastMath>; NVTE_CHECK_CUDA(cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream}; cutlass::Status status = cutlass::launch_kernel_on_cluster( params, (void const *)kernel_ptr, M, N, k_tile_size, cluster_shape, cluster_tile_shape, @@ -1226,9 +1226,8 @@ void row_col_rht_gemm_ntt_w_sfc( // clang-format on void hadamard_transform_cast_fusion(const Tensor &input_, Tensor &output_, - const Tensor &hadamard_matrix_, - QuantizationConfig quant_config, - cudaStream_t stream) { + const Tensor &hadamard_matrix_, QuantizationConfig quant_config, + cudaStream_t stream) { NVTE_API_CALL(hadamard_transform_cast_fusion); // Check input and output tensors @@ -1239,7 +1238,7 @@ void hadamard_transform_cast_fusion(const Tensor &input_, Tensor &output_, "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); const SimpleTensor &input = input_.data; - + // rowwise cast and columnwise cast has different output data pointers bool has_rowwise_quant = false; bool has_columnwise_quant = false; @@ -1265,7 +1264,8 @@ void hadamard_transform_cast_fusion(const Tensor &input_, Tensor &output_, columnwise_amax_ptr = output_.columnwise_amax.dptr; } - NVTE_CHECK(has_rowwise_quant || has_columnwise_quant, "Output tensor must have rowwise or columnwise quant."); + NVTE_CHECK(has_rowwise_quant || has_columnwise_quant, + "Output tensor must have rowwise or columnwise quant."); // Stochastic rounding config const bool use_stochastic_rounding = quant_config.stochastic_rounding; @@ -1324,52 +1324,52 @@ void hadamard_transform_cast_fusion(const Tensor &input_, Tensor &output_, const bool 
use_swizzle_sf_output = false; TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_stochastic_rounding, kEnableStochasticRounding, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - has_columnwise_quant, kEnableRhtColQuant, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - has_rowwise_quant, kEnableRowQuant, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_swizzle_sf_output, kEnableSwizzleSFOutput, + use_stochastic_rounding, kEnableStochasticRounding, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + has_columnwise_quant, kEnableRhtColQuant, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + has_rowwise_quant, kEnableRowQuant, TRANSFORMER_ENGINE_SWITCH_CONDITION( - quant_config.use_fast_math, kUseFastMath, - - if constexpr (kEnableRhtColQuant || kEnableRowQuant) { - detail::row_col_rht_gemm_ntt_w_sfc< - kEnableStochasticRounding, kEnableRhtColQuant, kEnableRowQuant, - kEnableSwizzleSFOutput, TA, TB, TD, TSFD, TQA, TSFA, kUseFastMath>( - /*sequence_length=*/m, /*hidden_size=*/n, - /*A=*/reinterpret_cast(input.dptr), - /*B=*/reinterpret_cast(hadamard_matrix.dptr), - /*D=*/reinterpret_cast(columnwise_data_ptr), - /*SFD=*/reinterpret_cast(columnwise_scale_inv_ptr), - /*QA=*/reinterpret_cast(rowwise_data_ptr), - /*SFA=*/reinterpret_cast(rowwise_scale_inv_ptr), - /*a_global_amax=*/reinterpret_cast(rowwise_amax_ptr), - /*d_global_amax=*/reinterpret_cast(columnwise_amax_ptr), - /*rng_state=*/rng_state, /*sm_count=*/sm_count, - /*stream=*/stream, /*k_tile_size=*/k_tile_size); - } else { - NVTE_ERROR("Invalid kernel configuration (kEnableRHTColQuant=", - kEnableRhtColQuant, ", kEnableRowQuant=", kEnableRowQuant, ")."); - } - - ););););); + use_swizzle_sf_output, kEnableSwizzleSFOutput, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + quant_config.use_fast_math, kUseFastMath, + + if constexpr (kEnableRhtColQuant || kEnableRowQuant) { + detail::row_col_rht_gemm_ntt_w_sfc< + kEnableStochasticRounding, kEnableRhtColQuant, kEnableRowQuant, + kEnableSwizzleSFOutput, TA, TB, TD, TSFD, TQA, TSFA, kUseFastMath>( + /*sequence_length=*/m, 
/*hidden_size=*/n, + /*A=*/reinterpret_cast(input.dptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*D=*/reinterpret_cast(columnwise_data_ptr), + /*SFD=*/reinterpret_cast(columnwise_scale_inv_ptr), + /*QA=*/reinterpret_cast(rowwise_data_ptr), + /*SFA=*/reinterpret_cast(rowwise_scale_inv_ptr), + /*a_global_amax=*/reinterpret_cast(rowwise_amax_ptr), + /*d_global_amax=*/reinterpret_cast(columnwise_amax_ptr), + /*rng_state=*/rng_state, /*sm_count=*/sm_count, + /*stream=*/stream, /*k_tile_size=*/k_tile_size); + } else { + NVTE_ERROR("Invalid kernel configuration (kEnableRHTColQuant=", + kEnableRhtColQuant, ", kEnableRowQuant=", kEnableRowQuant, ")."); + } + + ););););); } } // namespace transformer_engine void nvte_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor output, - const NVTETensor hadamard_matrix, - const NVTEQuantizationConfig quant_config, - cudaStream_t stream) { + const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { NVTE_API_CALL(nvte_hadamard_transform_cast_fusion); using namespace transformer_engine; QuantizationConfig quant_config_cpp; if (quant_config != nullptr) { quant_config_cpp = *reinterpret_cast(quant_config); } - hadamard_transform_cast_fusion( - *convertNVTETensorCheck(input), *convertNVTETensorCheck(output), - *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp, stream); + hadamard_transform_cast_fusion(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp, + stream); } diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h index 5dd44e1cbe..75729967a3 100644 --- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -72,10 +72,10 @@ void 
nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTE * \param[in] quant_config Quantization configuration. * \param[in] stream CUDA stream used for the operation. */ - void nvte_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor output, - const NVTETensor hadamard_matrix, - const NVTEQuantizationConfig quant_config, - cudaStream_t stream); +void nvte_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor output, + const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream); /*! \brief Split a tensor along dimension 0 and compute RHT amaxes for each split. * diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 5f1609f4cd..63a2e86e67 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -370,7 +370,11 @@ class NVFP4Quantizer : public Quantizer { private: void quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, bool compute_amax); - void quantize_with_rht_unfused_helper(const TensorWrapper& input, TensorWrapper& out, TensorWrapper& rht_output_t_cpp, QuantizationConfigWrapper& quant_config, QuantizationConfigWrapper& quant_config_columnwise, cudaStream_t stream); + void quantize_with_rht_unfused_helper(const TensorWrapper& input, TensorWrapper& out, + TensorWrapper& rht_output_t_cpp, + QuantizationConfigWrapper& quant_config, + QuantizationConfigWrapper& quant_config_columnwise, + cudaStream_t stream); }; std::unique_ptr convert_quantizer(py::handle quantizer); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 8be306adca..1a9af70680 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -7,9 +7,9 @@ #include #include "common.h" +#include "common/util/system.h" #include "pybind.h" #include "torch/torch.h" -#include 
"common/util/system.h" namespace transformer_engine::pytorch { @@ -2031,7 +2031,10 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( return {std::move(out_cpp), std::move(tensor)}; } -void NVFP4Quantizer::quantize_with_rht_unfused_helper(const TensorWrapper& input, TensorWrapper& out, TensorWrapper& rht_output_t_cpp, QuantizationConfigWrapper& quant_config, QuantizationConfigWrapper& quant_config_columnwise, cudaStream_t stream){ +void NVFP4Quantizer::quantize_with_rht_unfused_helper( + const TensorWrapper& input, TensorWrapper& out, TensorWrapper& rht_output_t_cpp, + QuantizationConfigWrapper& quant_config, QuantizationConfigWrapper& quant_config_columnwise, + cudaStream_t stream) { // only triggered for irregular shapes where RHT cast fusion kernel is not eligible if (rowwise_usage) { // For rowwise usage, we need to quantize the input directly, but we need to avoid quantizing columnwise @@ -2087,7 +2090,7 @@ void NVFP4Quantizer::quantize_with_rht_unfused_helper(const TensorWrapper& input static_cast(out_columnwise_amax.dtype), out_columnwise_amax.shape); - // Invoking fallback RHT kernel unfused. + // Invoking fallback RHT kernel unfused. NVTE_SCOPED_GIL_RELEASE({ // Perform the RHT(input.t), and write to rht_output_cpp.columnwise. @@ -2099,7 +2102,7 @@ void NVFP4Quantizer::quantize_with_rht_unfused_helper(const TensorWrapper& input // intended. NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), quant_config_columnwise, - stream); + stream); }); } } @@ -2140,14 +2143,15 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou TensorWrapper te_rng_state_columnwise; QuantizationConfigWrapper quant_config_columnwise; - // Only need a separate rng state when: + // Only need a separate rng state when: // 1. Stochastic rounding is enabled // 2. RHT is enabled // 3. Columnwise usage is enabled // 4. 
Rowwise and columnwise quantization are not fused, // because within a single kernel we can generate two different random numbers for rowwise and columnwise - const bool need_separate_columnwise_rng = - this->stochastic_rounding && this->with_rht && this->columnwise_usage && (!eligible_for_rht_cast_fusion); + const bool need_separate_columnwise_rng = this->stochastic_rounding && this->with_rht && + this->columnwise_usage && + (!eligible_for_rht_cast_fusion); if (this->stochastic_rounding) { const size_t rng_elts_per_thread = 1024; // Wild guess, probably can be tightened @@ -2256,24 +2260,29 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix); // Fusion kernel that does the following: // 1. Rowwise quantization - // 2. RHT followed by columnwise quantization & transpose - NVTE_SCOPED_GIL_RELEASE({ nvte_hadamard_transform_cast_fusion(input.data(), out.data(), rht_matrix_nvte.data(), quant_config, stream); }); + // 2. RHT followed by columnwise quantization & transpose + NVTE_SCOPED_GIL_RELEASE({ + nvte_hadamard_transform_cast_fusion(input.data(), out.data(), rht_matrix_nvte.data(), + quant_config, stream); + }); } else { // Use separate RNG state for columnwise to ensure different random numbers than rowwise - // This is only necessary because it's the unfused path where rowwise and columnwise + // This is only necessary because it's the unfused path where rowwise and columnwise // are separate kernel launches auto& columnwise_quant_config_to_use = - need_separate_columnwise_rng ? quant_config_columnwise : quant_config; + need_separate_columnwise_rng ? quant_config_columnwise : quant_config; // unfused path also needs memory allocation for intermediate buffer for RHT output at::Tensor rht_output_t; // The RHT(x_t) output, in columnwise layout // This wrapper is going to be passed as input to the quantization kernel. 
TensorWrapper rht_output_t_cpp; // Wrapper to contain the RHT(x) and RHT(x_t) outputs - rht_output_t = allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); + rht_output_t = + allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); // NOTE (frsun): This is non-intuitive, we are writing the // result of transposed RHT to the output of rowwise. rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), - std::vector{cols, rows}); - this->quantize_with_rht_unfused_helper(input, out, rht_output_t_cpp, quant_config, columnwise_quant_config_to_use, stream); + std::vector{cols, rows}); + this->quantize_with_rht_unfused_helper(input, out, rht_output_t_cpp, quant_config, + columnwise_quant_config_to_use, stream); } } else { NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); }); From 400b526b13d2002773f84143202309127c242aa7 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Fri, 2 Jan 2026 20:07:54 -0800 Subject: [PATCH 04/22] add benchmark script Signed-off-by: Zhongbo Zhu --- benchmarks/linear/benchmark_linear.py | 335 ++++++++++++++++++++++++++ 1 file changed, 335 insertions(+) create mode 100644 benchmarks/linear/benchmark_linear.py diff --git a/benchmarks/linear/benchmark_linear.py b/benchmarks/linear/benchmark_linear.py new file mode 100644 index 0000000000..3bbcb804ed --- /dev/null +++ b/benchmarks/linear/benchmark_linear.py @@ -0,0 +1,335 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import argparse +import torch +import torch.utils.benchmark as benchmark +import pandas as pd + +from transformer_engine.pytorch.module import Linear as TELinear +from transformer_engine.common.recipe import ( + Float8BlockScaling, + MXFP8BlockScaling, + NVFP4BlockScaling, +) +from transformer_engine.pytorch.quantization import autocast, FP8GlobalStateManager +from contextlib import nullcontext + +""" +# Profile BF16 recipe with Nsight Systems +nsys profile \ + --output=./benchmarks/linear/b200_linear_bf16 \ + --force-overwrite true \ + --trace=cuda,nvtx,cudnn,cublas \ + python benchmarks/linear/benchmark_linear.py --profile --recipe bf16 + +# Profile FP8 sub-channel recipe with Nsight Systems +nsys profile \ + --output=./benchmarks/linear/b200_linear_fp8_sub_channel \ + --force-overwrite true \ + --trace=cuda,nvtx,cudnn,cublas \ + python benchmarks/linear/benchmark_linear.py --profile --recipe fp8_sub_channel + +# Profile MXFP8 recipe with Nsight Systems +nsys profile \ + --output=./benchmarks/linear/b200_linear_mxfp8 \ + --force-overwrite true \ + --trace=cuda,nvtx,cudnn,cublas \ + python benchmarks/linear/benchmark_linear.py --profile --recipe mxfp8 + +# Profile NVFP4 recipe with Nsight Systems +nsys profile \ + --output=./benchmarks/linear/b200_linear_nvfp4_rht_cast_fusion \ + --force-overwrite true \ + --trace=cuda,nvtx,cudnn,cublas \ + python benchmarks/linear/benchmark_linear.py --profile --recipe nvfp4 + +# Example to look at a single kernel target with NCU, like the fused hadamard amax kernel for NVFP4 recipe +ncu -f -o ./benchmarks/linear/ncu_b200_linear_nvfp4_rht_cast_fusion \ + --set=full \ + --kernel-name "row_col_rht_gemm_device" \ + -s 5 -c 5 \ + python benchmarks/linear/benchmark_linear.py --profile --recipe nvfp4 + +""" + +RECIPES = { + "bf16": None, + "fp8_sub_channel": Float8BlockScaling(), + "mxfp8": MXFP8BlockScaling(), + "nvfp4": NVFP4BlockScaling(), +} + +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() 
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) +nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() + + +def run_linear_multiple_steps(layer, x, mode, gradient, run_num_steps=1, recipe=None): + assert mode in ["fwd_only", "fwd_bwd"] + quantization_context = ( + autocast(enabled=True, recipe=recipe) if recipe is not None else nullcontext() + ) + + if mode == "fwd_only": + with torch.no_grad(), quantization_context: + for i in range(run_num_steps): + y_q = layer.forward( + x, + is_first_microbatch=(i == 0), + ) + return y_q + else: + # reset gradients + layer.zero_grad() + x.grad = None + + with quantization_context: + for i in range(run_num_steps): + label = f"step_{i}" + torch.cuda.nvtx.range_push(label) + y_q = layer.forward( + x, + is_first_microbatch=(i == 0), + ) + y_q.backward(gradient) + torch.cuda.nvtx.range_pop() + + grads_q = [] + grads_q.append(x.grad) + # remaining derivatives are in respect to model parameters + for p in layer.parameters(): + if p.requires_grad: + grads_q.append(p.grad) + + return y_q, grads_q + + +def benchmark_linear( + x, + w, + bias, + recipe_name, + mode, +): + params_dtype = torch.bfloat16 + recipe = RECIPES[recipe_name] + + in_features = x.shape[1] + out_features = w.shape[0] + gradient = torch.ones((x.shape[0], out_features), dtype=torch.bfloat16, device=x.device) + + layer = TELinear( + in_features, + out_features, + bias=bias is not None, + params_dtype=params_dtype, + ) + + layer = layer.to("cuda") + with torch.no_grad(): + layer.weight.copy_(w) + if bias is not None: + layer.bias.copy_(bias) + + num_microbatches = 32 + + label = f"{recipe_name}_{'linear'}" + torch.cuda.nvtx.range_push(label) + timing = benchmark.Timer( + stmt=( + "run_linear_multiple_steps(layer, x, mode, gradient, num_microbatches," + " recipe)" + ), + globals={ + "run_linear_multiple_steps": run_linear_multiple_steps, + "layer": layer, + "x": x, + 
"mode": mode, + "gradient": gradient, + "num_microbatches": num_microbatches, + "recipe": recipe, + }, + num_threads=1, + ).blocked_autorange(min_run_time=10) + print(f"{recipe_name}: {timing} \n") + timing_ms = timing.median * 1000 / num_microbatches + + return timing_ms + + +def run_benchmark_linear( + mkns, recipe_name, use_bias, fwd_only=False +): + data = [] + assert not use_bias, "Bias is not supported in this benchmark script" + + print(f"========== Benchmarking {recipe_name} ==========") + for m, k, n in mkns: + device = "cuda" + x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True) + w = torch.randn((n, k), dtype=torch.bfloat16, device=device) + bias = None + + # Run the benchmark + print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}") + print(f"fwd_only: {fwd_only}") + + linear_fwd_bwd_timing_ms = benchmark_linear( + x, + w, + bias, + recipe_name, + mode="fwd_only" if fwd_only else "fwd_bwd", + ) + + # Append the results + data.append( + [ + m, + k, + n, + recipe_name, + linear_fwd_bwd_timing_ms, + ] + ) + + timing_notation = "linear_fwd_time_ms" if fwd_only else "linear_fwd_bwd_time_ms" + + df = pd.DataFrame( + data=data, + columns=[ + "m", + "k", + "n", + "recipe", + timing_notation, + ], + ) + + print(df, "\n") + return df + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("--profile", action="store_true", help="Enable profiling mode") + parser.add_argument( + "--output-dir", + type=str, + default="benchmark_output/", + help="output path for report", + ) + # arguments for recipe, options are fp8_sub_channel, mxfp8, bf16, all + parser.add_argument( + "--recipe", + type=str, + default="bf16", + help="Recipe to use, options are fp8_sub_channel, mxfp8, bf16, or all", + ) + parser.add_argument( + "--token-dim", + type=int, + default=None, + help="Token dimension to use, calculated by SEQ_LEN * MBS / TP_SIZE", + ) + parser.add_argument( + "--hidden-dim", + type=int, + default=None, + help="Hidden 
dimension to use", + ) + parser.add_argument( + "--output-dim", + type=int, + default=None, + help="Output dimension to use", + ) + parser.add_argument( + "--fwd-only", + action="store_true", + default=False, + help="Run forward pass only, default is both forward and backward passes", + ) + args = parser.parse_args() + + use_bias = False + + token_dim_list = [16384] + hidden_dim_list = [4096] + output_dim_list = [4096] + + if args.token_dim is not None: + token_dim_list = [args.token_dim] + + if args.hidden_dim is not None: + hidden_dim_list = [args.hidden_dim] + + if args.output_dim is not None: + output_dim_list = [args.output_dim] + + # MKN for linear + mkns = [] + for m in token_dim_list: + for k in hidden_dim_list: + for n in output_dim_list: + mkns.append((m, k, n)) + + # default recipes to run if not specified + recipe_list = ["bf16"] + + if args.recipe == "all": + recipe_list = ["bf16", "fp8_sub_channel", "mxfp8", "nvfp4"] + else: + recipe_list = [args.recipe] + + if args.profile: + hidden_dim_to_profile = 4096 if args.hidden_dim is None else args.hidden_dim + output_dim_to_profile = 4096 if args.output_dim is None else args.output_dim + token_dim_to_profile = 16384 if args.token_dim is None else args.token_dim + mkns = [(token_dim_to_profile, hidden_dim_to_profile, output_dim_to_profile)] + # in profile mode, only run one recipe specified in args.recipe + assert args.recipe != "all", ( + "In profile mode, only one recipe can be specified, please specify the recipe as" + " fp8_sub_channel, mxfp8, nvfp4, or bf16" + ) + recipe_list = [args.recipe] + torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__() + + # Initialize a dataframe to store the results + df_linears = pd.DataFrame() + + # Run the fp8 benchmarks + for recipe_name in recipe_list: + assert recipe_name in [ + "bf16", + "fp8_sub_channel", + "mxfp8", + "nvfp4", + ], "Recipe must be one of bf16, fp8_sub_channel, mxfp8, or nvfp4" + if recipe_name == "mxfp8" and not mxfp8_available: + 
print(f"MXFP8 is not available, skipping {recipe_name}") + continue + if recipe_name == "fp8_sub_channel" and not fp8_block_scaling_available: + print(f"FP8 block scaling is not available, skipping {recipe_name}") + continue + if recipe_name == "nvfp4" and not nvfp4_available: + print(f"NVFP4 is not available, skipping {recipe_name}") + continue + + df = run_benchmark_linear( + mkns, + recipe_name, + use_bias, + fwd_only=args.fwd_only, + ) + df_linears = pd.concat([df_linears, df]) + + print(df_linears) + + if args.profile: + torch.autograd.profiler.emit_nvtx().__exit__(None, None, None) From 044a83454c870fae26d0b3f9f2ffb2416f4a7b93 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Fri, 2 Jan 2026 20:16:17 -0800 Subject: [PATCH 05/22] lint and format Signed-off-by: Zhongbo Zhu --- benchmarks/linear/benchmark_linear.py | 9 ++------- ..._cast_col_hadamard_transform_cast_fusion.cu | 18 ++++++------------ 2 files changed, 8 insertions(+), 19 deletions(-) diff --git a/benchmarks/linear/benchmark_linear.py b/benchmarks/linear/benchmark_linear.py index 3bbcb804ed..b293c44fc9 100644 --- a/benchmarks/linear/benchmark_linear.py +++ b/benchmarks/linear/benchmark_linear.py @@ -140,10 +140,7 @@ def benchmark_linear( label = f"{recipe_name}_{'linear'}" torch.cuda.nvtx.range_push(label) timing = benchmark.Timer( - stmt=( - "run_linear_multiple_steps(layer, x, mode, gradient, num_microbatches," - " recipe)" - ), + stmt="run_linear_multiple_steps(layer, x, mode, gradient, num_microbatches, recipe)", globals={ "run_linear_multiple_steps": run_linear_multiple_steps, "layer": layer, @@ -161,9 +158,7 @@ def benchmark_linear( return timing_ms -def run_benchmark_linear( - mkns, recipe_name, use_bias, fwd_only=False -): +def run_benchmark_linear(mkns, recipe_name, use_bias, fwd_only=False): data = [] assert not use_bias, "Bias is not supported in this benchmark script" diff --git a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu 
b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu index 99fa711b3a..4ca8348bfb 100644 --- a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu @@ -436,8 +436,7 @@ __global__ static void row_col_rht_gemm_device( typename CLCPipeline::Params clc_pipeline_params; if (is_sched_warp) { clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; - } - else { + } else { clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; } clc_pipeline_params.producer_blockid = 0; @@ -549,9 +548,7 @@ __global__ static void row_col_rht_gemm_device( scheduler.update_work_tile_info(); } while (scheduler.is_valid()); mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); - } - - else if (is_mma_warp) { + } else if (is_mma_warp) { cutlass::arch::warpgroup_reg_dealloc<32>(); if constexpr (kEnableRHTColQuant) { Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) @@ -615,8 +612,7 @@ __global__ static void row_col_rht_gemm_device( accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); } - } - else if(is_sched_warp) { + } else if(is_sched_warp) { cutlass::arch::warpgroup_reg_dealloc<32>(); do { clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); @@ -627,8 +623,7 @@ __global__ static void row_col_rht_gemm_device( ++clc_pipeline_consumer_state; scheduler.update_work_tile_info(); } while (scheduler.is_valid()); - } - else if (is_epilogue_col_quant_warp) { + } else if (is_epilogue_col_quant_warp) { cutlass::arch::warpgroup_reg_alloc<192>(); if constexpr (kEnableRHTColQuant) { using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; @@ -848,8 +843,7 @@ __global__ static void row_col_rht_gemm_device( 
scheduler.update_work_tile_info(); } while (scheduler.is_valid()); } - } - else if (is_epilogue_row_quant_warp) { + } else if (is_epilogue_row_quant_warp) { cutlass::arch::warpgroup_reg_alloc<136>(); if constexpr (kEnableRowQuant) { using S2RVectorType = uint128_t; @@ -1008,7 +1002,7 @@ __global__ static void row_col_rht_gemm_device( } else { cutlass::arch::warpgroup_reg_dealloc<32>(); } -} +} // NOLINT(readability/fn_size) // this function computes RHT-GEMM for From 1d45051d47e83b3e8494d9a37cc6848d445cc9d2 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Fri, 2 Jan 2026 20:50:42 -0800 Subject: [PATCH 06/22] compile guard Signed-off-by: Zhongbo Zhu --- ..._cast_col_hadamard_transform_cast_fusion.cu | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu index 4ca8348bfb..6231f2d79b 100644 --- a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu @@ -179,6 +179,24 @@ __global__ static void row_col_rht_gemm_device( float const* c_global_amax, const size_t* rng_state) { using namespace cute; + + // Abort immediately if compilation is not supported + constexpr bool is_blackwell_arch = ARCH_BLACKWELL_FAMILY; + if constexpr (!is_blackwell_arch) { + NVTE_DEVICE_ERROR( + "row_col_rht_gemm_device is only supported on Blackwell " + "with architecture-specific compilation. 
" + "Try recompiling with sm_100a or similar."); + return; + } + static_assert(kEnableRHTColQuant_ || kEnableRowQuant_, + "row_col_rht_gemm_device must generate row-wise " + "and/or column-wise output."); +#if !defined(CUTLASS_ARCH_CLC_ENABLED) + CUTLASS_NOT_IMPLEMENTED(); + return; +#endif + using X = Underscore; // static constexpr bool kApplyStochasticRounding = true; using ElementAccumulator = float; From 95dfaf35730c71b90d69c7e1b33c4e195c15fec6 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Fri, 9 Jan 2026 16:39:07 -0800 Subject: [PATCH 07/22] warning fix Signed-off-by: Zhongbo Zhu --- ...cast_col_hadamard_transform_cast_fusion.cu | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu index 6231f2d79b..2eace1ce00 100644 --- a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu @@ -841,13 +841,13 @@ __global__ static void row_col_rht_gemm_device( } - copy_if(tiled_r2g, - [&](auto coord){ - Tensor pSrc_view = group_modes<1,rank(pSrc)>(pSrc); - return elem_less(pSrc_view(_0{},coord), shape(mD)); - }, - src, dst); - // 32bit vectorization copy 4 e4m3 SFD for per 64/(16,4):(0, 1) element + Tensor pred_pSrc = cute::lazy::transform(make_tensor(counting_iterator{}, replace<0>(shape(dst), _1{})), [&](auto coord){ + Tensor pSrc_view = group_modes<1,rank(pSrc)>(pSrc); + return elem_less(pSrc_view(_0{},coord), shape(mD)); + }); + copy_if(tiled_r2g, pred_pSrc, src, dst); + // 32bit vectorization copy 4 e4m3 SFD for per 64 or(16,4):(0, 1) element + constexpr int vec_len = 32 / sizeof_bits_v; Tensor tDrSFD_v = recast>(tDrSFD); Tensor tDgSFD_v = recast>(tDgSFD); @@ -995,13 +995,12 @@ __global__ static void 
row_col_rht_gemm_device( } } - copy_if(tiled_r2g_QA, - [&](auto coord){ - Tensor tQApQA_view = group_modes<1,rank(tQApQA_mn)>(tQApQA_mn); - return elem_less(tQApQA_view(_0{}, coord), shape(mQA)); - }, - tQArQA, tQAgQA_mn); - // 32bit vectorization copy 4 e4m3 SFA for per 64/(16,4):(0, 1) element + Tensor pred_tQApQA = cute::lazy::transform(make_tensor(counting_iterator{}, replace<0>(shape(tQAgQA_mn), _1{})), [&](auto coord){ + Tensor tQApQA_view = group_modes<1,rank(tQApQA_mn)>(tQApQA_mn); + return elem_less(tQApQA_view(_0{}, coord), shape(mQA)); + }); + copy_if(tiled_r2g_QA, pred_tQApQA, tQArQA, tQAgQA_mn); + // 32bit vectorization copy 4 e4m3 SFA for per 64 or (16,4):(0, 1) element constexpr int vec_len = 32 / sizeof_bits_v; Tensor tQArSFA_v = recast>(filter(tQArSFA)); Tensor tQAgSFA_v = recast>(filter(tQAgSFA_mn)); From 60d17768fd9482ba84f44c7272fab9a5494834d6 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Fri, 9 Jan 2026 17:12:35 -0800 Subject: [PATCH 08/22] resolve greptile comment Signed-off-by: Zhongbo Zhu --- transformer_engine/pytorch/csrc/quantizer.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 1a9af70680..f79f4d0f2c 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -2118,8 +2118,10 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou auto stream = at::cuda::getCurrentCUDAStream(); QuantizationConfigWrapper quant_config; + QuantizationConfigWrapper quant_config_columnwise; if (noop_flag) { quant_config.set_noop_tensor(noop_flag->data()); + quant_config_columnwise.set_noop_tensor(noop_flag->data()); } quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); quant_config.set_stochastic_rounding(this->stochastic_rounding); @@ -2141,7 +2143,6 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // 
we need separate RNG states for each to ensure they use different random numbers. TensorWrapper te_rng_state; TensorWrapper te_rng_state_columnwise; - QuantizationConfigWrapper quant_config_columnwise; // Only need a separate rng state when: // 1. Stochastic rounding is enabled @@ -2174,6 +2175,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou te_rng_state_columnwise = makeTransformerEngineTensor(rng_state_columnwise); quant_config_columnwise.set_stochastic_rounding(true); quant_config_columnwise.set_rng_state(te_rng_state_columnwise.data()); + quant_config_columnwise.set_nvfp4_2d_quantization(this->with_2d_quantization); } } From 62df6237f73273675b95a57b32ffe834c8bf8998 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Wed, 28 Jan 2026 11:04:40 -0800 Subject: [PATCH 09/22] minor style fixes Signed-off-by: Zhongbo Zhu --- .../row_cast_col_hadamard_transform_cast_fusion.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu index 2eace1ce00..c64c600dbe 100644 --- a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu @@ -361,7 +361,7 @@ __global__ static void row_col_rht_gemm_device( - // Allocate SMEMork + // Allocate SMEM extern __shared__ char shared_memory[]; using SharedStorage = SharedStorage; SharedStorage& shared_storage = *reinterpret_cast(shared_memory); @@ -1174,8 +1174,8 @@ void row_col_rht_gemm_ntt_w_sfc( mma); // Assert checks problem size should be multiple of 64 - assert(M % 64 == 0); - assert(N % 64 == 0); + NVTE_CHECK(M % 64 == 0, "M must be a multiple of 64, but got ", M); + NVTE_CHECK(N % 64 == 0, "N must be a multiple of 64, but got ", N); uint32_t tiles_in_m = 
uint32_t(size(ceil_div(M, size<0>(cluster_tile_shape)))); uint32_t tiles_in_n = uint32_t(size(ceil_div(N, k_tile_size))); From 675d4de78968f6de86fb6021f09b1fbfa3609f69 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Wed, 28 Jan 2026 16:38:12 -0800 Subject: [PATCH 10/22] fix namespace Signed-off-by: Zhongbo Zhu --- transformer_engine/common/gemm/cublaslt_grouped_gemm.cu | 2 +- transformer_engine/common/include/transformer_engine/gemm.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index b3e216dc4f..899e354c18 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -639,7 +639,7 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config, cudaStream_t stream) { NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.2+, but compile-time cuBLAS version is ", - CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer."); + CUBLAS_VERSION, ". Please upgrade to CUDA 13.2 or newer."); } #endif // CUBLAS_VERSION >= 130200 diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 7403448722..92713a5ba3 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -299,7 +299,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C * - * \note Requires cuBLAS 13.2+ (CUDA 13.1+) and Blackwell (SM100) or newer GPU architecture. + * \note Requires cuBLAS 13.2+ (CUDA 13.2+) and Blackwell (SM100) or newer GPU architecture. 
* Will error at runtime if compiled with an older cuBLAS version or run on * a pre-Blackwell GPU. * @@ -322,7 +322,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * \param[in] stream CUDA stream for the operation. * * Requirements: - * - cuBLAS 13.2+ (CUDA 13.1+) + * - cuBLAS 13.2+ (CUDA 13.2+) * - Blackwell (SM100) or newer GPU architecture * - A, B, C (if provided), D must have the same num_tensors * - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i] From 3942c6c2254420bfa75cb044be82edf09a43c7b9 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Tue, 24 Feb 2026 17:53:21 -0800 Subject: [PATCH 11/22] resolve some comments Signed-off-by: Zhongbo Zhu --- ...cast_col_hadamard_transform_cast_fusion.cu | 189 +++++++++--------- 1 file changed, 98 insertions(+), 91 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu index c64c600dbe..868404df05 100644 --- a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu @@ -48,16 +48,19 @@ namespace { using namespace cute; -// Ensure Tensor refers to cute::Tensor, not transformer_engine::Tensor -using cute::Tensor; - struct CLCResponse { uint32_t data[4] = {0}; }; +constexpr int kFp4ConvertChunkElements = 8; +constexpr int kFp4ConvertFullElements = 16; +constexpr int kFp4RbitsPerChunk = 2; +constexpr int kFp4ChunkCount = kFp4ConvertFullElements / kFp4ConvertChunkElements; + CUTLASS_DEVICE -cutlass::Array StochasticNumericConverterBase( - cutlass::Array const &input, cutlass::Array const &rbits) { - using result_type = cutlass::Array; +cutlass::Array StochasticNumericConverterBase( + cutlass::Array const &input, + cutlass::Array const &rbits) { + using result_type = cutlass::Array; 
result_type output; auto output_ptr = reinterpret_cast(&output); constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; @@ -79,15 +82,19 @@ cutlass::Array StochasticNumericConverterBase( } CUTLASS_DEVICE -cutlass::Array -StochasticNumericConverter(cutlass::Array const &input, cutlass::Array const &rbits) { - using result_type = cutlass::Array; +cutlass::Array +StochasticNumericConverter(cutlass::Array const &input, + cutlass::Array const &rbits) { + using result_type = cutlass::Array; result_type output; - cutlass::Array *result_ptr = reinterpret_cast *>(&output); - cutlass::Array const *source_ptr = reinterpret_cast const *>(&input); - cutlass::Array const *rbits_ptr = reinterpret_cast const *>(&rbits); + cutlass::Array *result_ptr = + reinterpret_cast *>(&output); + cutlass::Array const *source_ptr = + reinterpret_cast const *>(&input); + cutlass::Array const *rbits_ptr = + reinterpret_cast const *>(&rbits); CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; i++) { + for (int i = 0; i < kFp4ChunkCount; i++) { result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]); } return output; @@ -491,19 +498,19 @@ __global__ static void row_col_rht_gemm_device( if (is_dma_warp) { cutlass::arch::warpgroup_reg_dealloc<32>(); - Tensor mA = tma_load_a.get_tma_tensor(make_shape(M,N)); - Tensor mB = tma_load_b.get_tma_tensor(make_shape(RhtTensorSize, RhtTensorSize)); + cute::Tensor mA = tma_load_a.get_tma_tensor(make_shape(M,N)); + cute::Tensor mB = tma_load_b.get_tma_tensor(make_shape(RhtTensorSize, RhtTensorSize)); - Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_,_, _), Step<_1, X,_1>{}); - Tensor gB_nk = local_tile(mB, cluster_tile, make_coord(_,_, _), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + cute::Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_,_, _), Step<_1, X,_1>{}); + cute::Tensor gB_nk = local_tile(mB, cluster_tile, make_coord(_,_, _), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) - Tensor tCsA = 
make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) - Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + cute::Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + cute::Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE) int block_rank_in_cluster = cute::block_rank_in_cluster(); ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx - Tensor tCgA = thr_mma.partition_A(gA_mk); // (MMA,MMA_M,MMA_K,k) - Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) + cute::Tensor tCgA = thr_mma.partition_A(gA_mk); // (MMA,MMA_M,MMA_K,k) + cute::Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) Layout cta_layout_mnk = make_layout(cluster_shape); Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); @@ -569,14 +576,14 @@ __global__ static void row_col_rht_gemm_device( } else if (is_mma_warp) { cutlass::arch::warpgroup_reg_dealloc<32>(); if constexpr (kEnableRHTColQuant) { - Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) - Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + cute::Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + cute::Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE) int block_rank_in_cluster = cute::block_rank_in_cluster(); ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx // Allocate "fragments" -- these are actually umma smem descriptors - Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = thr_mma.make_fragment_B(tCsB); // 
(MMA,MMA_M,MMA_K,PIPE) + cute::Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + cute::Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) mma.accumulate_ = UMMA::ScaleOut::Zero; @@ -673,28 +680,28 @@ __global__ static void row_col_rht_gemm_device( rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0; } - Tensor mD = make_tensor( + cute::Tensor mD = make_tensor( cute::subbyte_iterator(D), make_shape(M,N), dD); // (M,N) - Tensor gD_mn = local_tile( + cute::Tensor gD_mn = local_tile( mD, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) - Tensor pD = make_identity_tensor(mD.shape()); - Tensor pD_mn = local_tile( + cute::Tensor pD = make_identity_tensor(mD.shape()); + cute::Tensor pD_mn = local_tile( pD, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) - Tensor mSFD = make_tensor(make_gmem_ptr(SFD), sfd_layout); - Tensor gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) - Tensor pSFD = make_identity_tensor(mSFD.shape()); - Tensor pSFD_mn = local_tile(pSFD, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) + cute::Tensor mSFD = make_tensor(make_gmem_ptr(SFD), sfd_layout); + cute::Tensor gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) + cute::Tensor pSFD = make_identity_tensor(mSFD.shape()); + cute::Tensor pSFD_mn = local_tile(pSFD, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) - Tensor gD_mn_view = tiled_divide(gD_mn, take<0,2>(epilogue_tiler)); - Tensor pD_mn_view = tiled_divide(pD_mn, take<0,2>(epilogue_tiler)); + cute::Tensor gD_mn_view = tiled_divide(gD_mn, take<0,2>(epilogue_tiler)); + cute::Tensor pD_mn_view = tiled_divide(pD_mn, take<0,2>(epilogue_tiler)); auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_,_,_,_0{})); auto tiled_r2g = make_tiled_copy_D( Copy_Atom{}, @@ -724,41 +731,41 @@ __global__ static 
void row_col_rht_gemm_device( scheduler.fetch_next_work(clc_pipeline, clc_pipeline_consumer_state); ++clc_pipeline_consumer_state; for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); ++k_tile) { - Tensor tDgD_mn = gD_mn_view(_,_,_,scheduler.tile_m(), scheduler.tile_n_base() + k_tile); - Tensor tDgSFD_mn = gSFD_mn(_,_,scheduler.tile_m(), scheduler.tile_n_base() + k_tile); - Tensor tDpD_mn = pD_mn_view(_,_,_,scheduler.tile_m(), scheduler.tile_n_base() + k_tile); - Tensor tDpSFD_mn = pSFD_mn(_,_,scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + cute::Tensor tDgD_mn = gD_mn_view(_,_,_,scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + cute::Tensor tDgSFD_mn = gSFD_mn(_,_,scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + cute::Tensor tDpD_mn = pD_mn_view(_,_,_,scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + cute::Tensor tDpSFD_mn = pSFD_mn(_,_,scheduler.tile_m(), scheduler.tile_n_base() + k_tile); accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); auto Acc = bulk_tmem_epilogue(_,_,_,accumulator_pipe_consumer_state.index()); - Tensor tDtAcc = thr_t2r.partition_S(Acc); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) - Tensor tDgD = thr_t2r.partition_D(tDgD_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) - Tensor tDpD = thr_t2r.partition_D(tDpD_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) - Tensor tTR_rAcc = make_tensor(shape(tDgD)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) - Tensor tDrD = make_tensor(shape(tDgD)); - Tensor tTR_rAcc_frag = recast>(coalesce(tTR_rAcc)); - Tensor tDrD_frag = recast>(coalesce(tDrD)); - - Tensor src = thr_r2g.retile_S(tDrD); - Tensor dst = thr_r2g.retile_D(tDgD); - Tensor pSrc = thr_r2g.retile_D(tDpD); - - Tensor tDgSFD_view = make_tensor( + cute::Tensor tDtAcc = thr_t2r.partition_S(Acc); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + cute::Tensor tDgD = thr_t2r.partition_D(tDgD_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + cute::Tensor tDpD = 
thr_t2r.partition_D(tDpD_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + cute::Tensor tTR_rAcc = make_tensor(shape(tDgD)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + cute::Tensor tDrD = make_tensor(shape(tDgD)); + cute::Tensor tTR_rAcc_frag = recast>(coalesce(tTR_rAcc)); + cute::Tensor tDrD_frag = recast>(coalesce(tDrD)); + + cute::Tensor src = thr_r2g.retile_S(tDrD); + cute::Tensor dst = thr_r2g.retile_D(tDgD); + cute::Tensor pSrc = thr_r2g.retile_D(tDpD); + + cute::Tensor tDgSFD_view = make_tensor( tDgSFD_mn.data(), make_layout( make_shape(shape(tDgSFD_mn), Int<1>{}, Int<1>{}), make_stride(stride(tDgSFD_mn), Int<0>{}, Int<0>{}))); - Tensor tDpSFD_view = make_tensor( + cute::Tensor tDpSFD_view = make_tensor( tDpSFD_mn.data(), make_layout( make_shape(shape(tDpSFD_mn), Int<1>{}, Int<1>{}), make_stride(stride(tDpSFD_mn), Int<0>{}, Int<0>{}))); - Tensor tDgSFD = filter(thr_t2r.partition_D(tDgSFD_view)); - Tensor tDrSFD = make_tensor(shape(tDgSFD)); - Tensor tDpSFD = filter(thr_t2r.partition_D(tDpSFD_view)); + cute::Tensor tDgSFD = filter(thr_t2r.partition_D(tDgSFD_view)); + cute::Tensor tDrSFD = make_tensor(shape(tDgSFD)); + cute::Tensor tDpSFD = filter(thr_t2r.partition_D(tDpSFD_view)); static int constexpr NumVecs = size(tDgD) / VectorSize; - Tensor tD_rRowSFD_frg = recast>(tDrSFD); + cute::Tensor tD_rRowSFD_frg = recast>(tDrSFD); cutlass::maximum_absolute_value_reduction, true> amax_reduction; cutlass::Array vec_maxs; @@ -841,19 +848,19 @@ __global__ static void row_col_rht_gemm_device( } - Tensor pred_pSrc = cute::lazy::transform(make_tensor(counting_iterator{}, replace<0>(shape(dst), _1{})), [&](auto coord){ - Tensor pSrc_view = group_modes<1,rank(pSrc)>(pSrc); + cute::Tensor pred_pSrc = cute::lazy::transform(make_tensor(counting_iterator{}, replace<0>(shape(dst), _1{})), [&](auto coord){ + cute::Tensor pSrc_view = group_modes<1,rank(pSrc)>(pSrc); return elem_less(pSrc_view(_0{},coord), shape(mD)); }); copy_if(tiled_r2g, pred_pSrc, src, dst); // 32bit vectorization 
copy 4 e4m3 SFD for per 64 or(16,4):(0, 1) element constexpr int vec_len = 32 / sizeof_bits_v; - Tensor tDrSFD_v = recast>(tDrSFD); - Tensor tDgSFD_v = recast>(tDgSFD); + cute::Tensor tDrSFD_v = recast>(tDrSFD); + cute::Tensor tDgSFD_v = recast>(tDgSFD); copy_if( [&](auto coord){ - Tensor tDpSFD_view = group_modes<1,rank(tDpSFD)>(tDpSFD); + cute::Tensor tDpSFD_view = group_modes<1,rank(tDpSFD)>(tDpSFD); return elem_less(tDpSFD_view(_0{}, coord * vec_len), shape(mSFD)); }, tDrSFD_v, tDgSFD_v); @@ -874,16 +881,16 @@ __global__ static void row_col_rht_gemm_device( rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0; } - Tensor mQA = make_tensor(cute::subbyte_iterator(QA), make_layout(make_shape(M, N), dQA)); - Tensor gQA_mn = local_tile(mQA, epilogue_tiler, make_coord(_,_, _), Step<_1,X,_1>{}); - Tensor pQA = make_identity_tensor(mQA.shape()); - Tensor pQA_mn = local_tile(pQA, epilogue_tiler, make_coord(_,_, _), Step<_1,X,_1>{}); - - Tensor mSFA = make_tensor(make_gmem_ptr(SFA), sfa_layout); - Tensor gSFA_mn = local_tile(mSFA, epilogue_tiler, make_coord(_,_, _), Step<_1,X,_1>{}); // (BLK_M,BLK_N) - Tensor pSFA = make_identity_tensor(mSFA.shape()); - Tensor pSFA_mn = local_tile(pSFA, epilogue_tiler, make_coord(_,_, _), Step<_1,X,_1>{}); - Tensor sA = as_position_independent_swizzle_tensor( + cute::Tensor mQA = make_tensor(cute::subbyte_iterator(QA), make_layout(make_shape(M, N), dQA)); + cute::Tensor gQA_mn = local_tile(mQA, epilogue_tiler, make_coord(_,_, _), Step<_1,X,_1>{}); + cute::Tensor pQA = make_identity_tensor(mQA.shape()); + cute::Tensor pQA_mn = local_tile(pQA, epilogue_tiler, make_coord(_,_, _), Step<_1,X,_1>{}); + + cute::Tensor mSFA = make_tensor(make_gmem_ptr(SFA), sfa_layout); + cute::Tensor gSFA_mn = local_tile(mSFA, epilogue_tiler, make_coord(_,_, _), Step<_1,X,_1>{}); // (BLK_M,BLK_N) + cute::Tensor pSFA = make_identity_tensor(mSFA.shape()); + cute::Tensor pSFA_mn = local_tile(pSFA, 
epilogue_tiler, make_coord(_,_, _), Step<_1,X,_1>{}); + cute::Tensor sA = as_position_independent_swizzle_tensor( group_modes<0,2>(coalesce(make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout)))); // (BLOCK_M, BLOCK_M,PIPE) using S2RWarpLayout = Layout>; using WarpGroupLayout = Layout>; @@ -898,17 +905,17 @@ __global__ static void row_col_rht_gemm_device( auto thr_s2r = tiled_s2r.get_slice(local_thread_idx); auto thr_r2g_QA = tiled_r2g_QA.get_slice(local_thread_idx); - Tensor tQAsA = thr_s2r.partition_S(sA); // (Copy, Copy_M, Copy_N, PIPE) + cute::Tensor tQAsA = thr_s2r.partition_S(sA); // (Copy, Copy_M, Copy_N, PIPE) - Tensor tQArA = make_tensor_like(make_layout(tQAsA(_, _, _, _0{}).shape())); // (Copy, Copy_M, Copy_N) + cute::Tensor tQArA = make_tensor_like(make_layout(tQAsA(_, _, _, _0{}).shape())); // (Copy, Copy_M, Copy_N) // Tensor tQArA_PI = thr_s2r.partition_S(sA_PI); - Tensor tQAgQA = thr_r2g_QA.partition_D(gQA_mn); - Tensor tQArQA = make_tensor_like(tQAgQA(_, _, _, _0{}, _0{})); - Tensor tQApQA = thr_r2g_QA.partition_D(pQA_mn); + cute::Tensor tQAgQA = thr_r2g_QA.partition_D(gQA_mn); + cute::Tensor tQArQA = make_tensor_like(tQAgQA(_, _, _, _0{}, _0{})); + cute::Tensor tQApQA = thr_r2g_QA.partition_D(pQA_mn); - Tensor tQAgSFA = thr_s2r.partition_D(gSFA_mn); - Tensor tQArSFA = make_tensor_like(tQAgSFA(_, _, _, _0{}, _0{})); - Tensor tQApSFA = thr_s2r.partition_D(pSFA_mn); + cute::Tensor tQAgSFA = thr_s2r.partition_D(gSFA_mn); + cute::Tensor tQArSFA = make_tensor_like(tQAgSFA(_, _, _, _0{}, _0{})); + cute::Tensor tQApSFA = thr_s2r.partition_D(pSFA_mn); // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} static constexpr float fp4_max = 6.0f; @@ -995,18 +1002,18 @@ __global__ static void row_col_rht_gemm_device( } } - Tensor pred_tQApQA = cute::lazy::transform(make_tensor(counting_iterator{}, replace<0>(shape(tQAgQA_mn), _1{})), [&](auto coord){ - Tensor tQApQA_view = 
group_modes<1,rank(tQApQA_mn)>(tQApQA_mn); + cute::Tensor pred_tQApQA = cute::lazy::transform(make_tensor(counting_iterator{}, replace<0>(shape(tQAgQA_mn), _1{})), [&](auto coord){ + cute::Tensor tQApQA_view = group_modes<1,rank(tQApQA_mn)>(tQApQA_mn); return elem_less(tQApQA_view(_0{}, coord), shape(mQA)); }); copy_if(tiled_r2g_QA, pred_tQApQA, tQArQA, tQAgQA_mn); // 32bit vectorization copy 4 e4m3 SFA for per 64 or (16,4):(0, 1) element constexpr int vec_len = 32 / sizeof_bits_v; - Tensor tQArSFA_v = recast>(filter(tQArSFA)); - Tensor tQAgSFA_v = recast>(filter(tQAgSFA_mn)); + cute::Tensor tQArSFA_v = recast>(filter(tQArSFA)); + cute::Tensor tQAgSFA_v = recast>(filter(tQAgSFA_mn)); copy_if( [&](auto coord){ - Tensor tQApSFA_view = filter(tQApSFA_mn); + cute::Tensor tQApSFA_view = filter(tQApSFA_mn); return elem_less(tQApSFA_view(_0{}, coord * vec_len), shape(mSFA)); }, tQArSFA_v, tQAgSFA_v); @@ -1075,12 +1082,12 @@ void row_col_rht_gemm_ntt_w_sfc( // Define shapes (dynamic) auto M = hidden_size; auto N = sequence_length; - Tensor tensorA = make_tensor(A, make_shape(hidden_size, sequence_length), LayoutLeft{}); - Tensor tensorB = make_tensor(B, make_shape(RhtTensorSize, RhtTensorSize), LayoutLeft{}); - Tensor tensorD = make_tensor(D, make_shape(hidden_size, sequence_length), LayoutRight{}); - Tensor tensorQA = make_tensor(QA, make_shape(hidden_size, sequence_length), LayoutLeft{}); - Tensor tensorSFD = make_tensor(SFD, sfd_layout); - Tensor tensorSFA = make_tensor(SFA, sfa_layout); + cute::Tensor tensorA = make_tensor(A, make_shape(hidden_size, sequence_length), LayoutLeft{}); + cute::Tensor tensorB = make_tensor(B, make_shape(RhtTensorSize, RhtTensorSize), LayoutLeft{}); + cute::Tensor tensorD = make_tensor(D, make_shape(hidden_size, sequence_length), LayoutRight{}); + cute::Tensor tensorQA = make_tensor(QA, make_shape(hidden_size, sequence_length), LayoutLeft{}); + cute::Tensor tensorSFD = make_tensor(SFD, sfd_layout); + cute::Tensor tensorSFA = make_tensor(SFA, 
sfa_layout); // Define strides (from tensors) auto dA = stride(tensorA); // (dM,dK) auto dB = stride(tensorB); // (dN,dK) From 50631ac30bab54e13cc054a64a53be86c868f32f Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Tue, 3 Mar 2026 12:05:42 -0800 Subject: [PATCH 12/22] fix comment Signed-off-by: Zhongbo Zhu --- transformer_engine/common/gemm/cublaslt_grouped_gemm.cu | 2 +- transformer_engine/common/include/transformer_engine/gemm.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 899e354c18..b3e216dc4f 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -639,7 +639,7 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config, cudaStream_t stream) { NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.2+, but compile-time cuBLAS version is ", - CUBLAS_VERSION, ". Please upgrade to CUDA 13.2 or newer."); + CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer."); } #endif // CUBLAS_VERSION >= 130200 diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 92713a5ba3..7403448722 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -299,7 +299,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C * - * \note Requires cuBLAS 13.2+ (CUDA 13.2+) and Blackwell (SM100) or newer GPU architecture. + * \note Requires cuBLAS 13.2+ (CUDA 13.1+) and Blackwell (SM100) or newer GPU architecture. 
* Will error at runtime if compiled with an older cuBLAS version or run on * a pre-Blackwell GPU. * @@ -322,7 +322,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor * \param[in] stream CUDA stream for the operation. * * Requirements: - * - cuBLAS 13.2+ (CUDA 13.2+) + * - cuBLAS 13.2+ (CUDA 13.1+) * - Blackwell (SM100) or newer GPU architecture * - A, B, C (if provided), D must have the same num_tensors * - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i] From ee63059ddf94413eaca510b041fdc3d171e0f5ed Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Tue, 3 Mar 2026 12:06:21 -0800 Subject: [PATCH 13/22] attempt to fix compile CI with guard Signed-off-by: Zhongbo Zhu --- .../group_row_cast_col_hadamard_transform_cast_fusion.cu | 5 +++-- .../row_cast_col_hadamard_transform_cast_fusion.cu | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index 1ef1f81e82..45ab8c501b 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -190,8 +190,8 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( "with architecture-specific compilation. 
" "Try recompiling with sm_100a or similar."); return; - } - static_assert(kEnableRHTColQuant_ || kEnableRowQuant_, + } else { + static_assert(kEnableRHTColQuant_ || kEnableRowQuant_, "group_row_col_rht_gemm_device must generate row-wise " "and/or column-wise output."); #if !defined(CUTLASS_ARCH_CLC_ENABLED) @@ -1117,6 +1117,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( } else { cutlass::arch::warpgroup_reg_dealloc<32>(); } + } // sm100 compile guard end } // NOLINT(readability/fn_size) template (); } + } // sm100 compile guard end } // NOLINT(readability/fn_size) From 2f97846bd07d7cb5f319481b29d0dcedfe990913 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Mar 2026 20:07:17 +0000 Subject: [PATCH 14/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ...cast_col_hadamard_transform_cast_fusion.cu | 1728 +++++++++-------- 1 file changed, 867 insertions(+), 861 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu index 45ab8c501b..0bc72d6e11 100644 --- a/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -192,932 +192,938 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device( return; } else { static_assert(kEnableRHTColQuant_ || kEnableRowQuant_, - "group_row_col_rht_gemm_device must generate row-wise " - "and/or column-wise output."); + "group_row_col_rht_gemm_device must generate row-wise " + "and/or column-wise output."); #if !defined(CUTLASS_ARCH_CLC_ENABLED) - CUTLASS_NOT_IMPLEMENTED(); - return; + CUTLASS_NOT_IMPLEMENTED(); + return; #endif - using X = 
Underscore; - // Accumulator data type for main computation - using ElementAccumulator = float; - static int constexpr K_PIPE_MAX = size<3>(ASmemLayout{}); - using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>; - static uint32_t constexpr kTmaTransactionBytes = cutlass::bits_to_bytes( - size(AtomThrShapeMNK{}) * cosize(take<0, 3>(ASmemLayout{})) * cute::sizeof_bits_v); - static constexpr bool kEnableStochasticRounding = kEnableStochasticRounding_; - static constexpr bool kEnableRHTColQuant = kEnableRHTColQuant_; - static constexpr bool kEnableRowQuant = kEnableRowQuant_; - static constexpr bool kEnableSwizzleSFOutput = kEnableSwizzleSFOutput_; - static constexpr bool kUseFastMath = kUseFastMath_; - - // Constant for RHT tensor processing (tile size etc) - static int constexpr RhtTensorSize = 16; - - // Transaction bytes for TMA transfer on RHT tensor blocks - static int constexpr kTmaRhtTensorTransactionBytes = - cutlass::bits_to_bytes(RhtTensorSize * RhtTensorSize * cute::sizeof_bits_v); - static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; - static int constexpr SchedulerPipelineStageCount = SchedulerPipelineStageCount_; - - // Mainloop pipeline stage calculation, vectorization parameters for scaling factors - static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{}); - static int constexpr SFVecSize = 16; - // Swizzle output layout for scaling factor arrays - using SwizzledSFALayoutAtom = - cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; - using SwizzledSFDLayoutAtom = - cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; - - // Mainloop pipeline types for TMA async execution and epilogue cluster scheduling - using MainloopPipeline = - cutlass::detail::CustomizedPipelineTmaUmmaAsync; - using MainloopPipelineState = typename MainloopPipeline::PipelineState; - using SchedPipeline = cutlass::PipelineCLCFetchAsync; - using SchedPipelineState = typename SchedPipeline::PipelineState; - 
using SchedThrottlePipeline = cutlass::PipelineAsync; - using SchedThrottlePipelineState = typename SchedThrottlePipeline::PipelineState; - - static_assert(ClusterShape{} == Shape<_1, _1, _1>{}, "ClusterShape must be Shape<_1,_1,_1>"); - - using TmemAllocator = cute::TMEM::Allocator1Sm; - static int constexpr VectorSize = RhtTensorSize; - - // Compile-time safety: static shapes required for shared memory layouts - CUTE_STATIC_ASSERT(is_static::value); - CUTE_STATIC_ASSERT(is_static::value); - // CUTE_STATIC_ASSERT(is_static::value); - - auto cluster_size = size<0>(cluster_shape); - auto mainloop_tiler = Shape<_128, _16, _128>{}; - auto epilogue_tiler = Shape<_128, _128, _128>{}; - - static int constexpr EpilogueUnrollFactor = size<2>(epilogue_tiler) / size<2>(cluster_tile); - - // Get the appropriate blocks for this Cluster - dim3 cluster_coord_in_grid = cluster_id_in_grid(); - - // Total number of k-tiles - int const K_TILE_MAX = min(packed_N, K) / size<2>(epilogue_tiler); - - struct TileScheduler { - uint32_t tiles_in_m = 0; - uint32_t tiles_in_n = 0; - uint32_t linear_idx = 0; - uint32_t next_linear_idx = 0; - uint32_t start_idx = 0; - uint32_t tile_m_idx = 0; - uint32_t tile_n_idx = 0; - int k_tile_max = 0; - uint32_t *atomic_tile_index_; - uint32_t *smem_tile_counter; - uint32_t atomic_offset; - cutlass::FastDivmodU64 divmod_tiles_in_m; - - CUTLASS_DEVICE TileScheduler(uint32_t tiles_m, uint32_t tiles_n, int kmax, - uint32_t *atomic_tile_index, uint32_t *smem_tile_counter) - : tiles_in_m(tiles_m), - tiles_in_n(tiles_n), - linear_idx(blockIdx.x), - next_linear_idx(blockIdx.x), - start_idx(blockIdx.x), - k_tile_max(kmax), - atomic_tile_index_(atomic_tile_index), - smem_tile_counter(smem_tile_counter), - atomic_offset(gridDim.x), - divmod_tiles_in_m(uint64_t(tiles_m)) { - update_tile_idx(); - } - CUTLASS_DEVICE void update_tile_idx() { - uint64_t q, r; - divmod_tiles_in_m(q, r, uint64_t(linear_idx)); - tile_m_idx = static_cast(r); - tile_n_idx = static_cast(q) * 
uint32_t(k_tile_max); - } - CUTLASS_DEVICE uint32_t tile_m() const { return tile_m_idx; } - CUTLASS_DEVICE uint32_t tile_n_base() const { return tile_n_idx; } - CUTLASS_DEVICE uint32_t tiles_m() const { return tiles_in_m; } + using X = Underscore; + // Accumulator data type for main computation + using ElementAccumulator = float; + static int constexpr K_PIPE_MAX = size<3>(ASmemLayout{}); + using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>; + static uint32_t constexpr kTmaTransactionBytes = cutlass::bits_to_bytes( + size(AtomThrShapeMNK{}) * cosize(take<0, 3>(ASmemLayout{})) * cute::sizeof_bits_v); + static constexpr bool kEnableStochasticRounding = kEnableStochasticRounding_; + static constexpr bool kEnableRHTColQuant = kEnableRHTColQuant_; + static constexpr bool kEnableRowQuant = kEnableRowQuant_; + static constexpr bool kEnableSwizzleSFOutput = kEnableSwizzleSFOutput_; + static constexpr bool kUseFastMath = kUseFastMath_; + + // Constant for RHT tensor processing (tile size etc) + static int constexpr RhtTensorSize = 16; + + // Transaction bytes for TMA transfer on RHT tensor blocks + static int constexpr kTmaRhtTensorTransactionBytes = + cutlass::bits_to_bytes(RhtTensorSize * RhtTensorSize * cute::sizeof_bits_v); + static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; + static int constexpr SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + + // Mainloop pipeline stage calculation, vectorization parameters for scaling factors + static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{}); + static int constexpr SFVecSize = 16; + // Swizzle output layout for scaling factor arrays + using SwizzledSFALayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + using SwizzledSFDLayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + + // Mainloop pipeline types for TMA async execution and epilogue cluster scheduling + using MainloopPipeline = + 
cutlass::detail::CustomizedPipelineTmaUmmaAsync; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + using SchedPipeline = cutlass::PipelineCLCFetchAsync; + using SchedPipelineState = typename SchedPipeline::PipelineState; + using SchedThrottlePipeline = cutlass::PipelineAsync; + using SchedThrottlePipelineState = typename SchedThrottlePipeline::PipelineState; + + static_assert(ClusterShape{} == Shape<_1, _1, _1>{}, "ClusterShape must be Shape<_1,_1,_1>"); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + static int constexpr VectorSize = RhtTensorSize; + + // Compile-time safety: static shapes required for shared memory layouts + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + // CUTE_STATIC_ASSERT(is_static::value); + + auto cluster_size = size<0>(cluster_shape); + auto mainloop_tiler = Shape<_128, _16, _128>{}; + auto epilogue_tiler = Shape<_128, _128, _128>{}; + + static int constexpr EpilogueUnrollFactor = size<2>(epilogue_tiler) / size<2>(cluster_tile); + + // Get the appropriate blocks for this Cluster + dim3 cluster_coord_in_grid = cluster_id_in_grid(); + + // Total number of k-tiles + int const K_TILE_MAX = min(packed_N, K) / size<2>(epilogue_tiler); + + struct TileScheduler { + uint32_t tiles_in_m = 0; + uint32_t tiles_in_n = 0; + uint32_t linear_idx = 0; + uint32_t next_linear_idx = 0; + uint32_t start_idx = 0; + uint32_t tile_m_idx = 0; + uint32_t tile_n_idx = 0; + int k_tile_max = 0; + uint32_t *atomic_tile_index_; + uint32_t *smem_tile_counter; + uint32_t atomic_offset; + cutlass::FastDivmodU64 divmod_tiles_in_m; + + CUTLASS_DEVICE TileScheduler(uint32_t tiles_m, uint32_t tiles_n, int kmax, + uint32_t *atomic_tile_index, uint32_t *smem_tile_counter) + : tiles_in_m(tiles_m), + tiles_in_n(tiles_n), + linear_idx(blockIdx.x), + next_linear_idx(blockIdx.x), + start_idx(blockIdx.x), + k_tile_max(kmax), + atomic_tile_index_(atomic_tile_index), + smem_tile_counter(smem_tile_counter), + 
atomic_offset(gridDim.x), + divmod_tiles_in_m(uint64_t(tiles_m)) { + update_tile_idx(); + } + CUTLASS_DEVICE void update_tile_idx() { + uint64_t q, r; + divmod_tiles_in_m(q, r, uint64_t(linear_idx)); + tile_m_idx = static_cast(r); + tile_n_idx = static_cast(q) * uint32_t(k_tile_max); + } + CUTLASS_DEVICE uint32_t tile_m() const { return tile_m_idx; } + CUTLASS_DEVICE uint32_t tile_n_base() const { return tile_n_idx; } + CUTLASS_DEVICE uint32_t tiles_m() const { return tiles_in_m; } - CUTLASS_DEVICE uint32_t tiles_n() const { return tiles_in_n; } + CUTLASS_DEVICE uint32_t tiles_n() const { return tiles_in_n; } - CUTLASS_DEVICE bool is_valid() const { - return cute::elem_less(cute::make_coord(tile_m(), tile_n_base()), - cute::make_coord(tiles_in_m, tiles_in_n)); - } - - CUTLASS_DEVICE bool is_first_wave() const { return linear_idx == start_idx; } + CUTLASS_DEVICE bool is_valid() const { + return cute::elem_less(cute::make_coord(tile_m(), tile_n_base()), + cute::make_coord(tiles_in_m, tiles_in_n)); + } - CUTLASS_DEVICE uint32_t get_linear_tile_idx() const { return linear_idx; } + CUTLASS_DEVICE bool is_first_wave() const { return linear_idx == start_idx; } - // Fetch a new tile_id using atomics. - CUTLASS_DEVICE uint32_t fetch_tile_id_counter(int pred) { - uint32_t tile_id_counter = 0; - asm volatile( - "{\n\t" - ".reg .pred p;\n\t" - "setp.eq.u32 p, %2, 1;\n\t" - "@p atom.global.add.u32 %0, [%1], 1; \n\t" - "}" - : "=r"(tile_id_counter) - : "l"(atomic_tile_index_), "r"(pred)); + CUTLASS_DEVICE uint32_t get_linear_tile_idx() const { return linear_idx; } - return tile_id_counter; - } + // Fetch a new tile_id using atomics. 
+ CUTLASS_DEVICE uint32_t fetch_tile_id_counter(int pred) { + uint32_t tile_id_counter = 0; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.eq.u32 p, %2, 1;\n\t" + "@p atom.global.add.u32 %0, [%1], 1; \n\t" + "}" + : "=r"(tile_id_counter) + : "l"(atomic_tile_index_), "r"(pred)); - CUTLASS_DEVICE auto fetch_next_work(SchedPipeline &sched_pipeline, - SchedPipelineState sched_pipeline_consumer_state) { - sched_pipeline.consumer_wait(sched_pipeline_consumer_state); - next_linear_idx = smem_tile_counter[sched_pipeline_consumer_state.index()]; - cutlass::arch::fence_view_async_shared(); - sched_pipeline.consumer_release(sched_pipeline_consumer_state); - return; - } + return tile_id_counter; + } - CUTLASS_DEVICE auto advance_to_next_work(SchedPipeline &sched_pipeline, - SchedPipelineState sched_pipeline_producer_state) { - uint32_t mbarrier_addr = sched_pipeline.producer_get_barrier(sched_pipeline_producer_state); - // Wait for clcID buffer to become empty with a flipped phase - sched_pipeline.producer_acquire(sched_pipeline_producer_state); - auto is_leading_thread = cute::elect_one_sync(); - uint32_t tile_id_counter = fetch_tile_id_counter(is_leading_thread) + atomic_offset; - uint32_t smem_addr = - cute::cast_smem_ptr_to_uint(&smem_tile_counter[sched_pipeline_producer_state.index()]); - if (is_leading_thread) { - cute::store_shared_remote(tile_id_counter, smem_addr, mbarrier_addr, 0); + CUTLASS_DEVICE auto fetch_next_work(SchedPipeline &sched_pipeline, + SchedPipelineState sched_pipeline_consumer_state) { + sched_pipeline.consumer_wait(sched_pipeline_consumer_state); + next_linear_idx = smem_tile_counter[sched_pipeline_consumer_state.index()]; + cutlass::arch::fence_view_async_shared(); + sched_pipeline.consumer_release(sched_pipeline_consumer_state); + return; } - ++sched_pipeline_producer_state; - return sched_pipeline_producer_state; - } + CUTLASS_DEVICE auto advance_to_next_work(SchedPipeline &sched_pipeline, + SchedPipelineState 
sched_pipeline_producer_state) { + uint32_t mbarrier_addr = sched_pipeline.producer_get_barrier(sched_pipeline_producer_state); + // Wait for clcID buffer to become empty with a flipped phase + sched_pipeline.producer_acquire(sched_pipeline_producer_state); + auto is_leading_thread = cute::elect_one_sync(); + uint32_t tile_id_counter = fetch_tile_id_counter(is_leading_thread) + atomic_offset; + uint32_t smem_addr = + cute::cast_smem_ptr_to_uint(&smem_tile_counter[sched_pipeline_producer_state.index()]); + if (is_leading_thread) { + cute::store_shared_remote(tile_id_counter, smem_addr, mbarrier_addr, 0); + } - CUTLASS_DEVICE auto update_work_tile_info() { - linear_idx = next_linear_idx; - update_tile_idx(); - return; - } - }; - - // Allocate and alias shared memory to the kernel's shared storage type - extern __shared__ char shared_memory[]; - using SharedStorage = - SharedStorage; - SharedStorage &shared_storage = *reinterpret_cast(shared_memory); - - // Compute the number of tiles in M and N after tiling and assign scheduler - uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile)))); - uint32_t tiles_in_n = uint32_t( - size(ceil_div(args.split_sections_range[args.num_tensors], size<2>(epilogue_tiler)))); - - TileScheduler scheduler(tiles_in_m, tiles_in_n, K_TILE_MAX, tile_scheduler_workspace, - shared_storage.atomic_tile_counter); - - int block_rank_in_cluster = cute::block_rank_in_cluster(); - - // Shapes for accumulated tiles in mainloop and epilogue - auto acc_shape_mma = make_shape(take<0, 2>(mainloop_tiler), _1{}, _1{}); - auto acc_shape_epilogue = make_shape(take<0, 2>(epilogue_tiler), _1{}, _1{}); - - // Shape of the accumulator fragment for the main loop pipeline, with pipeline stages appended - auto acc_mainloop_pipelined_shape = append(acc_shape_mma, Int{}); - auto bulk_tmem_mma = TiledMMA::make_fragment_C(acc_mainloop_pipelined_shape); - - // Number of threads assigned for various epilogue roles depending on quantization settings - 
static int constexpr NumEpilogueColQuantThreadCount = kEnableRHTColQuant ? 128 : 0; - static int constexpr NumEpilogueRowQuantThreadCount = kEnableRowQuant ? 256 : 0; - static int constexpr NumMmaThreadCount = kEnableRHTColQuant ? 32 : 0; - static int constexpr NumMmaIssueThreadCount = kEnableRHTColQuant ? 1 : 0; - static int constexpr NumSchedThreads = 32; - static int constexpr NumMainloopLoadThreads = 32; - static int constexpr NumEpilogueThreads = - NumEpilogueColQuantThreadCount + NumEpilogueRowQuantThreadCount; - - TmemAllocator tmem_allocator{}; - cutlass::arch::NamedBarrier tmem_allocation_result_barrier( - NumMmaThreadCount + NumEpilogueColQuantThreadCount, - cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); - - int warp_idx = cutlass::canonical_warp_idx_sync(); - - // warp assignment - bool is_mma_warp = (warp_idx == 0); - bool is_dma_warp = (warp_idx == 1); - bool is_sched_warp = (warp_idx == 2); - bool is_epilogue_col_quant_warp = (warp_idx >= 4 && warp_idx <= 7); - bool is_epilogue_row_quant_warp = (warp_idx >= 8 && warp_idx <= 15); - - typename MainloopPipeline::Params mainloop_pipeline_params; - if (is_dma_warp) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; - } - if (is_mma_warp) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; - } - mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp; - mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes; - mainloop_pipeline_params.initializing_warp = 0; - mainloop_pipeline_params.num_consumers = NumEpilogueRowQuantThreadCount + NumMmaIssueThreadCount; + ++sched_pipeline_producer_state; + return sched_pipeline_producer_state; + } - MainloopPipeline mainloop_pipeline(shared_storage.mainloop, mainloop_pipeline_params, - cluster_shape, cute::true_type{}, // Perform barrier init - cute::true_type{}); // Delay mask calculation + CUTLASS_DEVICE auto update_work_tile_info() { + linear_idx = next_linear_idx; + 
update_tile_idx(); + return; + } + }; - MainloopPipelineState mainloop_pipe_consumer_state; - MainloopPipelineState mainloop_pipe_producer_state = - cutlass::make_producer_start_state(); + // Allocate and alias shared memory to the kernel's shared storage type + extern __shared__ char shared_memory[]; + using SharedStorage = + SharedStorage; + SharedStorage &shared_storage = *reinterpret_cast(shared_memory); - using AccumulatorPipeline = - cutlass::PipelineUmmaAsync; - using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; - using AccumulatorPipelineInitBarriers = cute::bool_constant; + // Compute the number of tiles in M and N after tiling and assign scheduler + uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile)))); + uint32_t tiles_in_n = uint32_t( + size(ceil_div(args.split_sections_range[args.num_tensors], size<2>(epilogue_tiler)))); - AccumulatorPipelineState accumulator_pipe_consumer_state; - AccumulatorPipelineState accumulator_pipe_producer_state = - cutlass::make_producer_start_state(); + TileScheduler scheduler(tiles_in_m, tiles_in_n, K_TILE_MAX, tile_scheduler_workspace, + shared_storage.atomic_tile_counter); - typename AccumulatorPipeline::Params accumulator_pipeline_params; - if (is_mma_warp) { - accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; - } - if (is_epilogue_col_quant_warp) { - accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; - } - // Only one producer thread arrives on this barrier. 
- accumulator_pipeline_params.producer_arv_count = 1; - accumulator_pipeline_params.consumer_arv_count = - size(AtomThrShapeMNK{}) * NumEpilogueColQuantThreadCount; - accumulator_pipeline_params.initializing_warp = 1; - AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, accumulator_pipeline_params, - cluster_shape, AccumulatorPipelineInitBarriers{}, - cute::true_type{}); // Delay mask calculation - typename SchedPipeline::Params sched_pipeline_params; - if (is_sched_warp) { - sched_pipeline_params.role = SchedPipeline::ThreadCategory::ProducerConsumer; - } else { - sched_pipeline_params.role = SchedPipeline::ThreadCategory::Consumer; - } - sched_pipeline_params.producer_blockid = 0; - sched_pipeline_params.producer_arv_count = 1; - sched_pipeline_params.consumer_arv_count = - NumSchedThreads + - cluster_size * (NumMainloopLoadThreads + NumEpilogueThreads + NumMmaThreadCount); - sched_pipeline_params.transaction_bytes = sizeof(uint32_t); - sched_pipeline_params.initializing_warp = 3; - SchedPipeline sched_pipeline(shared_storage.sched, sched_pipeline_params, cluster_shape); - SchedPipelineState sched_pipeline_consumer_state; - SchedPipelineState sched_pipeline_producer_state = - cutlass::make_producer_start_state(); - - typename SchedThrottlePipeline::Params sched_throttle_pipeline_params; - if (is_dma_warp) { - sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Producer; - } - if (is_sched_warp) { - sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Consumer; - } - sched_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; - sched_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; - sched_throttle_pipeline_params.dst_blockid = 0; - sched_throttle_pipeline_params.initializing_warp = 4; - - SchedThrottlePipeline sched_throttle_pipeline(shared_storage.sched_throttle, - sched_throttle_pipeline_params); - SchedThrottlePipelineState sched_pipeline_throttle_consumer_state; - 
SchedThrottlePipelineState sched_pipeline_throttle_producer_state = - cutlass::make_producer_start_state(); - - if (warp_idx == 2 && elect_one_sync()) { - cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); - } - __syncthreads(); - - // Warp group roles: DMA (global->shared copy), MMA (tensor core gemm), scheduler, column quantizer, row quantizer - if (is_dma_warp) { - // Warp responsible for loading input from global to shared memory using TMA (Tensor Memory Access). - cutlass::arch::warpgroup_reg_dealloc<32>(); - // Get TMA tensors for input matrix A and B (Hadamard/transform matrix) from global memory. - Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, packed_N)); - Tensor mB = tma_load_b.get_tma_tensor(make_shape(RhtTensorSize, RhtTensorSize)); - - // Partition tensors for tiling according to the mainloop and cluster tilers. - Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); - Tensor gB_nk = - local_tile(mB, cluster_tile, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,k) - - // Shared memory tensors for pipeline - Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), - sAlayout); // (MMA,MMA_M,MMA_N,PIPE) - Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), - sBlayout); // (MMA,MMA_N,MMA_K,PIPE) - - // Determine warp/tile positioning int block_rank_in_cluster = cute::block_rank_in_cluster(); - ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx - // Partition global to local fragments for A and B - Tensor tCgA = thr_mma.partition_A(gA_mk); // (MMA,MMA_M,MMA_K,k) - Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) - - Layout cta_layout_mnk = make_layout(cluster_shape); - Layout cta_layout_vmnk = - tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); - auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); - - auto [tAgA, tAsA] = - tma_partition(tma_load_a, 
get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), - group_modes<0, 3>(tCsA), group_modes<0, 3>(tCgA)); - - auto [tBgB, tBsB] = - tma_partition(tma_load_b, get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), - group_modes<0, 3>(tCsB), group_modes<0, 3>(tCgB)); - - uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); - uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); - if constexpr (kEnableRHTColQuant) { - if (elect_one_sync()) { - cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], - kTmaRhtTensorTransactionBytes); - copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_, 0, 0), - tBsB(_, 0)); - } - } - do { - // is_first_wave indicates whether this scheduler wave is the first among a group. - bool is_first_wave = scheduler.is_first_wave(); - uint32_t skip_wait = is_first_wave; - auto tAgA_mk = tAgA(_, scheduler.tile_m(), _); - int k_tile = 0; - - sched_throttle_pipeline.producer_acquire(sched_pipeline_throttle_producer_state); - sched_throttle_pipeline.producer_commit(sched_pipeline_throttle_producer_state); - ++sched_pipeline_throttle_producer_state; - CUTLASS_PRAGMA_NO_UNROLL - while (k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n()) { - int k_tile_idx_n = scheduler.tile_n_base() + k_tile; - ++k_tile; - skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount); - mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); - using BarrierType = typename MainloopPipeline::ProducerBarrierType; - BarrierType *tma_barrier = - mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); - int write_stage = mainloop_pipe_producer_state.index(); - ++mainloop_pipe_producer_state; - if (cute::elect_one_sync()) { - copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_, k_tile_idx_n), - tAsA(_, write_stage)); - } - } - scheduler.fetch_next_work(sched_pipeline, 
sched_pipeline_consumer_state); - ++sched_pipeline_consumer_state; - scheduler.update_work_tile_info(); - // scheduler.advance(); - } while (scheduler.is_valid()); - mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); - } else if (is_mma_warp) { - // This warp executes the main tensor core matrix-multiply-accumulate for the Hadamard transform. - cutlass::arch::warpgroup_reg_dealloc<32>(); - if constexpr (kEnableRHTColQuant) { - // Setup shared memory fragments for A and B tiles. + // Shapes for accumulated tiles in mainloop and epilogue + auto acc_shape_mma = make_shape(take<0, 2>(mainloop_tiler), _1{}, _1{}); + auto acc_shape_epilogue = make_shape(take<0, 2>(epilogue_tiler), _1{}, _1{}); + + // Shape of the accumulator fragment for the main loop pipeline, with pipeline stages appended + auto acc_mainloop_pipelined_shape = append(acc_shape_mma, Int{}); + auto bulk_tmem_mma = TiledMMA::make_fragment_C(acc_mainloop_pipelined_shape); + + // Number of threads assigned for various epilogue roles depending on quantization settings + static int constexpr NumEpilogueColQuantThreadCount = kEnableRHTColQuant ? 128 : 0; + static int constexpr NumEpilogueRowQuantThreadCount = kEnableRowQuant ? 256 : 0; + static int constexpr NumMmaThreadCount = kEnableRHTColQuant ? 32 : 0; + static int constexpr NumMmaIssueThreadCount = kEnableRHTColQuant ? 
1 : 0; + static int constexpr NumSchedThreads = 32; + static int constexpr NumMainloopLoadThreads = 32; + static int constexpr NumEpilogueThreads = + NumEpilogueColQuantThreadCount + NumEpilogueRowQuantThreadCount; + + TmemAllocator tmem_allocator{}; + cutlass::arch::NamedBarrier tmem_allocation_result_barrier( + NumMmaThreadCount + NumEpilogueColQuantThreadCount, + cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + + // warp assignment + bool is_mma_warp = (warp_idx == 0); + bool is_dma_warp = (warp_idx == 1); + bool is_sched_warp = (warp_idx == 2); + bool is_epilogue_col_quant_warp = (warp_idx >= 4 && warp_idx <= 7); + bool is_epilogue_row_quant_warp = (warp_idx >= 8 && warp_idx <= 15); + + typename MainloopPipeline::Params mainloop_pipeline_params; + if (is_dma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (is_mma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp; + mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + mainloop_pipeline_params.num_consumers = + NumEpilogueRowQuantThreadCount + NumMmaIssueThreadCount; + + MainloopPipeline mainloop_pipeline(shared_storage.mainloop, mainloop_pipeline_params, + cluster_shape, cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = + cutlass::make_producer_start_state(); + + using AccumulatorPipeline = + cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + using AccumulatorPipelineInitBarriers = cute::bool_constant; + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState 
accumulator_pipe_producer_state = + cutlass::make_producer_start_state(); + + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (is_mma_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (is_epilogue_col_quant_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = + size(AtomThrShapeMNK{}) * NumEpilogueColQuantThreadCount; + accumulator_pipeline_params.initializing_warp = 1; + AccumulatorPipeline accumulator_pipeline( + shared_storage.accumulator, accumulator_pipeline_params, cluster_shape, + AccumulatorPipelineInitBarriers{}, cute::true_type{}); // Delay mask calculation + typename SchedPipeline::Params sched_pipeline_params; + if (is_sched_warp) { + sched_pipeline_params.role = SchedPipeline::ThreadCategory::ProducerConsumer; + } else { + sched_pipeline_params.role = SchedPipeline::ThreadCategory::Consumer; + } + sched_pipeline_params.producer_blockid = 0; + sched_pipeline_params.producer_arv_count = 1; + sched_pipeline_params.consumer_arv_count = + NumSchedThreads + + cluster_size * (NumMainloopLoadThreads + NumEpilogueThreads + NumMmaThreadCount); + sched_pipeline_params.transaction_bytes = sizeof(uint32_t); + sched_pipeline_params.initializing_warp = 3; + SchedPipeline sched_pipeline(shared_storage.sched, sched_pipeline_params, cluster_shape); + SchedPipelineState sched_pipeline_consumer_state; + SchedPipelineState sched_pipeline_producer_state = + cutlass::make_producer_start_state(); + + typename SchedThrottlePipeline::Params sched_throttle_pipeline_params; + if (is_dma_warp) { + sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Producer; + } + if (is_sched_warp) { + sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Consumer; + } + 
sched_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + sched_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + sched_throttle_pipeline_params.dst_blockid = 0; + sched_throttle_pipeline_params.initializing_warp = 4; + + SchedThrottlePipeline sched_throttle_pipeline(shared_storage.sched_throttle, + sched_throttle_pipeline_params); + SchedThrottlePipelineState sched_pipeline_throttle_consumer_state; + SchedThrottlePipelineState sched_pipeline_throttle_producer_state = + cutlass::make_producer_start_state(); + + if (warp_idx == 2 && elect_one_sync()) { + cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); + } + __syncthreads(); + + // Warp group roles: DMA (global->shared copy), MMA (tensor core gemm), scheduler, column quantizer, row quantizer + if (is_dma_warp) { + // Warp responsible for loading input from global to shared memory using TMA (Tensor Memory Access). + cutlass::arch::warpgroup_reg_dealloc<32>(); + // Get TMA tensors for input matrix A and B (Hadamard/transform matrix) from global memory. + Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, packed_N)); + Tensor mB = tma_load_b.get_tma_tensor(make_shape(RhtTensorSize, RhtTensorSize)); + + // Partition tensors for tiling according to the mainloop and cluster tilers. 
+ Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor gB_nk = + local_tile(mB, cluster_tile, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,k) + + // Shared memory tensors for pipeline Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + // Determine warp/tile positioning int block_rank_in_cluster = cute::block_rank_in_cluster(); ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx - // Allocate "fragments" -- these are actually umma smem descriptors - Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) - - mma.accumulate_ = UMMA::ScaleOut::Zero; - - tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, - &shared_storage.tmem_base_ptr); - __syncwarp(); - tmem_allocation_result_barrier.arrive(); - uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; - bulk_tmem_mma.data() = tmem_base_ptr; - // Wait until the B (Hadamard) tensor copy is complete - cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); - do { - uint32_t skip_wait = K_TILE_MAX <= 0; + // Partition global to local fragments for A and B + Tensor tCgA = thr_mma.partition_A(gA_mk); // (MMA,MMA_M,MMA_K,k) + Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) + + Layout cta_layout_mnk = make_layout(cluster_shape); + Layout cta_layout_vmnk = + tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + auto [tAgA, tAsA] = + tma_partition(tma_load_a, get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0, 3>(tCsA), group_modes<0, 3>(tCgA)); + + auto [tBgB, tBsB] = + tma_partition(tma_load_b, get<1>(cta_coord_vmnk), 
make_layout(size<1>(cta_layout_vmnk)), + group_modes<0, 3>(tCsB), group_modes<0, 3>(tCgB)); + + uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + if constexpr (kEnableRHTColQuant) { + if (elect_one_sync()) { + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], + kTmaRhtTensorTransactionBytes); + copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_, 0, 0), + tBsB(_, 0)); + } + } - auto barrier_token = - mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); - scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); - ++sched_pipeline_consumer_state; + do { + // is_first_wave indicates whether this scheduler wave is the first among a group. + bool is_first_wave = scheduler.is_first_wave(); + uint32_t skip_wait = is_first_wave; + auto tAgA_mk = tAgA(_, scheduler.tile_m(), _); + int k_tile = 0; + + sched_throttle_pipeline.producer_acquire(sched_pipeline_throttle_producer_state); + sched_throttle_pipeline.producer_commit(sched_pipeline_throttle_producer_state); + ++sched_pipeline_throttle_producer_state; CUTLASS_PRAGMA_NO_UNROLL - for (int k_tile = 0; - k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) { - mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); - int read_stage = mainloop_pipe_consumer_state.index(); - auto tCrA_mk = tCrA(_, _, _, read_stage); - auto tCrB_nk = tCrB(_, _, 0, 0); - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA) / EpilogueUnrollFactor; ++k_block) { - int accumulator_k_block = - accumulator_pipe_producer_state.index() * EpilogueUnrollFactor; - int tCrA_k_block = k_block * EpilogueUnrollFactor; - accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < EpilogueUnrollFactor; i++) { - auto 
accumulators = bulk_tmem_mma(_, _, _, accumulator_k_block + i); - gemm(mma, tCrA_mk(_, _, tCrA_k_block + i), tCrB_nk, accumulators); - } - - accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); - ++accumulator_pipe_producer_state; - } - auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; - ++mainloop_pipe_consumer_state; + while (k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n()) { + int k_tile_idx_n = scheduler.tile_n_base() + k_tile; ++k_tile; - skip_wait = k_tile >= K_TILE_MAX; - mainloop_pipeline.umma_consumer_release(curr_mainloop_pipe_consumer_state); - barrier_token = - mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount); + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType *tma_barrier = + mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + if (cute::elect_one_sync()) { + copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_, k_tile_idx_n), + tAsA(_, write_stage)); + } } + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; scheduler.update_work_tile_info(); + // scheduler.advance(); } while (scheduler.is_valid()); - tmem_allocator.release_allocation_lock(); - accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); - tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); - } - } else if (is_sched_warp) { - // Scheduler warp manages tile assignment and pipeline progress for warps - cutlass::arch::warpgroup_reg_dealloc<32>(); - do { - sched_throttle_pipeline.consumer_wait(sched_pipeline_throttle_consumer_state); - sched_throttle_pipeline.consumer_release(sched_pipeline_throttle_consumer_state); 
- ++sched_pipeline_throttle_consumer_state; - sched_pipeline_producer_state = - scheduler.advance_to_next_work(sched_pipeline, sched_pipeline_producer_state); - scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); - ++sched_pipeline_consumer_state; - scheduler.update_work_tile_info(); - } while (scheduler.is_valid()); - } else if (is_epilogue_col_quant_warp) { - // Warp responsible for quantizing output of Hadamard transform to FP4 for columnwise usage, - // and writing result tensors/scales to global memory. - cutlass::arch::warpgroup_reg_alloc<192>(); - if constexpr (kEnableRHTColQuant) { - using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; - - auto acc_epilogue_pipelined_shape = - append(acc_shape_epilogue, Int{}); - auto bulk_tmem_epilogue_layout = make_layout( - acc_epilogue_pipelined_shape, - make_stride(stride<0>(bulk_tmem_mma), Int<0>{}, Int<0>{}, size<1>(epilogue_tiler))); - auto bulk_tmem_epilogue = make_tensor(make_tmem_ptr(), bulk_tmem_epilogue_layout); - - // Use 256-bit fragments for aligned bulk stores - static int constexpr FragmentSize = 256 / sizeof_bits_v; - - // Wait for TMEM allocation for this pipeline to finish - tmem_allocation_result_barrier.arrive_and_wait(); - uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; - bulk_tmem_epilogue.data() = tmem_base_ptr; - int global_thread_idx = threadIdx.x; - int local_thread_idx = global_thread_idx % cutlass::NumThreadsPerWarpGroup; - // g2s load all global_d_amax - CUTLASS_PRAGMA_NO_UNROLL - for (int g = local_thread_idx; g < args.num_tensors; g += NumEpilogueColQuantThreadCount) { - shared_storage.global_d_amax[g] = - __ldg(reinterpret_cast(args.global_d_amax_list[g])); - } - - size_t rng_seed = 0; - size_t rng_offset = 0; - // Setup RNG for stochastic rounding - if constexpr (kEnableStochasticRounding) { - rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; - rng_offset = rng_state != nullptr ? 
__ldg(rng_state + 1) : 0; - } - int group_idx = GetGroupIdx(&args, scheduler.tile_n_base() * size<1>(epilogue_tiler)); - - // Determine quantization scale factor layouts/output splits for this group - TSFDLayout sfd_layout; - int cur_N = args.split_sections[group_idx]; - if constexpr (kEnableSwizzleSFOutput) { - sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{}); - } else { - sfd_layout = make_layout(make_shape(M, make_shape(Int{}, cur_N / SFVecSize)), - make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{}))); - } - // Build output tensors for columns and their quant scales - Tensor mD = make_tensor( - cute::subbyte_iterator(reinterpret_cast(args.output_colwise_list[group_idx])), - make_shape(M, cur_N), DStride{}); // (M,packed_N) - Tensor gD_mn = - local_tile(mD, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N) - - Tensor mSFD = make_tensor(make_gmem_ptr(reinterpret_cast( - args.output_colwise_scale_inv_list[group_idx])), - sfd_layout); - Tensor gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _), - Step<_1, _1, X>{}); // (BLK_M,BLK_N) - - Tensor gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler)); - - // Setup tile-level TMEM (t2r) and global memory (r2g) copy descriptors - auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_, _, _, _0{})); - auto tiled_r2g = - make_tiled_copy_D(Copy_Atom{}, tiled_t2r); - auto thr_t2r = tiled_t2r.get_slice(local_thread_idx); - auto thr_r2g = tiled_r2g.get_slice(local_thread_idx); - - cutlass::arch::NamedBarrier::sync(NumEpilogueColQuantThreadCount, - cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} - static constexpr float fp4_max = 6.0f; - static constexpr float fp8_max = 448.0f; - static constexpr float fp4_max_inv = 1.0f / fp4_max; - float c_global_amax_val = shared_storage.global_d_amax[group_idx]; - float global_encode_scale = 
c_global_amax_val > 0.0f - ? cutlass::minimum_with_nan_propagation{}( - (fp8_max * fp4_max) / c_global_amax_val, - cutlass::platform::numeric_limits::max()) - : 1.0f; - float global_decode_scale = 1.0f / global_encode_scale; - - // Scaling factor for fast math path - float global_encode_scale_multiplier = 1.0f; - if constexpr (kUseFastMath) { - global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } else if (is_mma_warp) { + // This warp executes the main tensor core matrix-multiply-accumulate for the Hadamard transform. + cutlass::arch::warpgroup_reg_dealloc<32>(); + if constexpr (kEnableRHTColQuant) { + // Setup shared memory fragments for A and B tiles. + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), + sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), + sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + // Allocate "fragments" -- these are actually umma smem descriptors + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) + + mma.accumulate_ = UMMA::ScaleOut::Zero; + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, + &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_mma.data() = tmem_base_ptr; + // Wait until the B (Hadamard) tensor copy is complete + cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); + do { + uint32_t skip_wait = K_TILE_MAX <= 0; + + auto barrier_token = + mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + 
++sched_pipeline_consumer_state; + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) { + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + int read_stage = mainloop_pipe_consumer_state.index(); + auto tCrA_mk = tCrA(_, _, _, read_stage); + auto tCrB_nk = tCrB(_, _, 0, 0); + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) / EpilogueUnrollFactor; ++k_block) { + int accumulator_k_block = + accumulator_pipe_producer_state.index() * EpilogueUnrollFactor; + int tCrA_k_block = k_block * EpilogueUnrollFactor; + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < EpilogueUnrollFactor; i++) { + auto accumulators = bulk_tmem_mma(_, _, _, accumulator_k_block + i); + gemm(mma, tCrA_mk(_, _, tCrA_k_block + i), tCrB_nk, accumulators); + } + + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + ++accumulator_pipe_producer_state; + } + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + ++mainloop_pipe_consumer_state; + ++k_tile; + skip_wait = k_tile >= K_TILE_MAX; + mainloop_pipeline.umma_consumer_release(curr_mainloop_pipe_consumer_state); + barrier_token = + mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + } + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + tmem_allocator.release_allocation_lock(); + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); } - + } else if (is_sched_warp) { + // Scheduler warp manages tile assignment and pipeline progress for warps + cutlass::arch::warpgroup_reg_dealloc<32>(); do { + sched_throttle_pipeline.consumer_wait(sched_pipeline_throttle_consumer_state); + sched_throttle_pipeline.consumer_release(sched_pipeline_throttle_consumer_state); + 
++sched_pipeline_throttle_consumer_state; + sched_pipeline_producer_state = + scheduler.advance_to_next_work(sched_pipeline, sched_pipeline_producer_state); scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); ++sched_pipeline_consumer_state; + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } else if (is_epilogue_col_quant_warp) { + // Warp responsible for quantizing output of Hadamard transform to FP4 for columnwise usage, + // and writing result tensors/scales to global memory. + cutlass::arch::warpgroup_reg_alloc<192>(); + if constexpr (kEnableRHTColQuant) { + using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; + + auto acc_epilogue_pipelined_shape = + append(acc_shape_epilogue, Int{}); + auto bulk_tmem_epilogue_layout = make_layout( + acc_epilogue_pipelined_shape, + make_stride(stride<0>(bulk_tmem_mma), Int<0>{}, Int<0>{}, size<1>(epilogue_tiler))); + auto bulk_tmem_epilogue = make_tensor(make_tmem_ptr(), bulk_tmem_epilogue_layout); + + // Use 256-bit fragments for aligned bulk stores + static int constexpr FragmentSize = 256 / sizeof_bits_v; + + // Wait for TMEM allocation for this pipeline to finish + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_epilogue.data() = tmem_base_ptr; + int global_thread_idx = threadIdx.x; + int local_thread_idx = global_thread_idx % cutlass::NumThreadsPerWarpGroup; + // g2s load all global_d_amax CUTLASS_PRAGMA_NO_UNROLL - for (int k_tile = 0; - k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); - ++k_tile) { - int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler); - - int cur_group_idx = GetGroupIdx(&args, global_tile_n_offset); - - if (cur_group_idx != group_idx) { - group_idx = cur_group_idx; - c_global_amax_val = shared_storage.global_d_amax[group_idx]; - // update amax - global_encode_scale = c_global_amax_val > 0.0f - ? 
cutlass::minimum_with_nan_propagation{}( - (fp8_max * fp4_max) / c_global_amax_val, - cutlass::platform::numeric_limits::max()) - : 1.0f; - global_decode_scale = 1.0f / global_encode_scale; - if constexpr (kUseFastMath) { - global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; - } - cur_N = args.split_sections[group_idx]; - if constexpr (kEnableSwizzleSFOutput) { - sfd_layout = - tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{}); - } else { - sfd_layout = - make_layout(make_shape(M, make_shape(Int{}, cur_N / SFVecSize)), - make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{}))); - } - // update tensor - mD = make_tensor(cute::subbyte_iterator( - reinterpret_cast(args.output_colwise_list[group_idx])), - make_shape(M, cur_N), DStride{}); - gD_mn = local_tile(mD, epilogue_tiler, make_coord(_, _, _), - Step<_1, _1, X>{}); // (BLK_M,BLK_N) - mSFD = make_tensor(make_gmem_ptr(reinterpret_cast( - args.output_colwise_scale_inv_list[group_idx])), - sfd_layout); - gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _), - Step<_1, _1, X>{}); // (BLK_M,BLK_N) + for (int g = local_thread_idx; g < args.num_tensors; g += NumEpilogueColQuantThreadCount) { + shared_storage.global_d_amax[g] = + __ldg(reinterpret_cast(args.global_d_amax_list[g])); + } - gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler)); - } - int group_start_offset = args.split_sections_range[group_idx]; - int local_tile_n_idx = - (global_tile_n_offset - group_start_offset) / size<1>(epilogue_tiler); - Tensor tDgD_mn = gD_mn_view(_, _, _, scheduler.tile_m(), local_tile_n_idx); - - Tensor tDgSFD_mn = gSFD_mn(_, _, scheduler.tile_m(), local_tile_n_idx); - accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); - - auto Acc = bulk_tmem_epilogue(_, _, _, accumulator_pipe_consumer_state.index()); - Tensor tDtAcc = thr_t2r.partition_S(Acc); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) - Tensor tDgD = thr_t2r.partition_D(tDgD_mn); // 
((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) - - Tensor tTR_rAcc = - make_tensor(shape(tDgD)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) - Tensor tDrD = make_tensor(shape(tDgD)); - Tensor tTR_rAcc_frag = - recast>(coalesce(tTR_rAcc)); - Tensor tDrD_frag = recast>(coalesce(tDrD)); - - Tensor src = thr_r2g.retile_S(tDrD); - Tensor dst = thr_r2g.retile_D(tDgD); - - Tensor tDgSFD_view = make_tensor( - tDgSFD_mn.data(), make_layout(make_shape(shape(tDgSFD_mn), Int<1>{}, Int<1>{}), - make_stride(stride(tDgSFD_mn), Int<0>{}, Int<0>{}))); - Tensor tDgSFD = filter(thr_t2r.partition_D(tDgSFD_view)); - Tensor tDrSFD = make_tensor(shape(tDgSFD)); - - static int constexpr NumVecs = size(tDgD) / VectorSize; - Tensor tD_rRowSFD_frg = recast>(tDrSFD); - - // Compute amax and quantization scales for this tile - cutlass::maximum_absolute_value_reduction, - true> - amax_reduction; - cutlass::Array vec_maxs; - cutlass::Array pvscales; - // Copy from TMEM to registers - copy(tiled_t2r, tDtAcc, tTR_rAcc); - cutlass::arch::fence_view_async_tmem_load(); - accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); - ++accumulator_pipe_consumer_state; - - if constexpr (!kUseFastMath) { - // Downcast to BF16 for bit-wise compatibility with - // unfused kernels - auto convert_accum_to_bf16 = - cutlass::NumericArrayConverter{}; - auto convert_bf16_to_accum = - cutlass::NumericArrayConverter{}; - tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); - tTR_rAcc_frag(_1{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_1{}))); - } + size_t rng_seed = 0; + size_t rng_offset = 0; + // Setup RNG for stochastic rounding + if constexpr (kEnableStochasticRounding) { + rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; + rng_offset = rng_state != nullptr ? 
__ldg(rng_state + 1) : 0; + } + int group_idx = GetGroupIdx(&args, scheduler.tile_n_base() * size<1>(epilogue_tiler)); + + // Determine quantization scale factor layouts/output splits for this group + TSFDLayout sfd_layout; + int cur_N = args.split_sections[group_idx]; + if constexpr (kEnableSwizzleSFOutput) { + sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{}); + } else { + sfd_layout = make_layout(make_shape(M, make_shape(Int{}, cur_N / SFVecSize)), + make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{}))); + } + // Build output tensors for columns and their quant scales + Tensor mD = make_tensor( + cute::subbyte_iterator(reinterpret_cast(args.output_colwise_list[group_idx])), + make_shape(M, cur_N), DStride{}); // (M,packed_N) + Tensor gD_mn = local_tile(mD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) - auto compute_frgs = reinterpret_cast *>( - tTR_rAcc_frag.data()); - auto output_frgs = reinterpret_cast *>(tDrD_frag.data()); - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < NumVecs; v++) { - vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); - } + Tensor mSFD = make_tensor(make_gmem_ptr(reinterpret_cast( + args.output_colwise_scale_inv_list[group_idx])), + sfd_layout); + Tensor gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) + + Tensor gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler)); + + // Setup tile-level TMEM (t2r) and global memory (r2g) copy descriptors + auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_, _, _, _0{})); + auto tiled_r2g = + make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_t2r = tiled_t2r.get_slice(local_thread_idx); + auto thr_r2g = tiled_r2g.get_slice(local_thread_idx); + + cutlass::arch::NamedBarrier::sync(NumEpilogueColQuantThreadCount, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + // Aligning with TensorEngine's recipe to generate scale factors 
// {$nv-internal-release} + static constexpr float fp4_max = 6.0f; + static constexpr float fp8_max = 448.0f; + static constexpr float fp4_max_inv = 1.0f / fp4_max; + float c_global_amax_val = shared_storage.global_d_amax[group_idx]; + float global_encode_scale = c_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / c_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + float global_decode_scale = 1.0f / global_encode_scale; + + // Scaling factor for fast math path + float global_encode_scale_multiplier = 1.0f; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } - if constexpr (kUseFastMath) { - // Fast math: multiply with precomputed reciprocal - pvscales = cutlass::multiplies>{}( - vec_maxs, global_encode_scale_multiplier); - } else { - // Accurate math: perform division - pvscales = - cutlass::divides>{}(vec_maxs, fp4_max); - pvscales = cutlass::multiplies>{}( - pvscales, global_encode_scale); - } - auto pvscales_cvted = - cutlass::NumericArrayConverter{}(pvscales); - - tD_rRowSFD_frg(_0{}) = pvscales_cvted; - auto qpvscale_ups = cutlass::NumericArrayConverter{}( - tD_rRowSFD_frg(_0{})); - auto qpvscale_scaled = cutlass::multiplies>{}( - qpvscale_ups, global_decode_scale); - cutlass::Array acc_scales; - if constexpr (kUseFastMath) { - // Fast math: compute approximate reciprocal - acc_scales = - cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); - } else { - // Accurate math: compute reciprocal with division - acc_scales = cutlass::divides>{}( - 1.0, qpvscale_scaled); - } + do { + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); + ++k_tile) { + int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler); + + int cur_group_idx = 
GetGroupIdx(&args, global_tile_n_offset); + + if (cur_group_idx != group_idx) { + group_idx = cur_group_idx; + c_global_amax_val = shared_storage.global_d_amax[group_idx]; + // update amax + global_encode_scale = c_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / c_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + global_decode_scale = 1.0f / global_encode_scale; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + cur_N = args.split_sections[group_idx]; + if constexpr (kEnableSwizzleSFOutput) { + sfd_layout = + tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{}); + } else { + sfd_layout = + make_layout(make_shape(M, make_shape(Int{}, cur_N / SFVecSize)), + make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{}))); + } + // update tensor + mD = make_tensor(cute::subbyte_iterator( + reinterpret_cast(args.output_colwise_list[group_idx])), + make_shape(M, cur_N), DStride{}); + gD_mn = local_tile(mD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) + mSFD = make_tensor(make_gmem_ptr(reinterpret_cast( + args.output_colwise_scale_inv_list[group_idx])), + sfd_layout); + gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) - // Prepare stochastic rounding random state if enabled - uint4 random_uint4 = uint4{0, 0, 0, 0}; - transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; - // "Prefetch" a stochastic rounding state for the first tile - if constexpr (kEnableStochasticRounding) { - const size_t rng_sequence = global_thread_idx + k_tile * 512 + - scheduler.get_linear_tile_idx() * K_TILE_MAX * 512; - rng.init(rng_seed, rng_sequence, rng_offset); - } - CUTLASS_PRAGMA_UNROLL - // Apply round/quantize to each fragment, with or without stochastic rounding - for (int v = 0; v < NumVecs; v++) { - auto acc_scale = 
cutlass::minimum_with_nan_propagation{}( - acc_scales[v], cutlass::platform::numeric_limits::max()); - if constexpr (kEnableStochasticRounding) { - random_uint4 = rng.generate4(); - output_frgs[v] = StochasticNumericConverter( - cutlass::multiplies>{}( - compute_frgs[v], acc_scale), - *reinterpret_cast *>(&random_uint4)); - } else { - output_frgs[v] = cutlass::NumericArrayConverter{}( - cutlass::multiplies>{}( - compute_frgs[v], acc_scale)); + gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler)); } - } - - // Write quantized FP4 tile and dequant scale to gmem - copy(tiled_r2g, src, dst); - copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFD, tDgSFD); - } - scheduler.update_work_tile_info(); - } while (scheduler.is_valid()); - } - } else if (is_epilogue_row_quant_warp) { - // Warp responsible for quantizing the input (before Hadamard transform) to FP4 for row-wise usage. - cutlass::arch::warpgroup_reg_alloc<136>(); - if constexpr (kEnableRowQuant) { - using S2RVectorType = uint128_t; - - int global_thread_idx = threadIdx.x; - int local_thread_idx = global_thread_idx % 256; - size_t rng_seed = 0; - size_t rng_offset = 0; - // g2s load all global_a_amax for all groups/tensors - CUTLASS_PRAGMA_NO_UNROLL - for (int g = local_thread_idx; g < args.num_tensors; g += NumEpilogueRowQuantThreadCount) { - shared_storage.global_a_amax[g] = - __ldg(reinterpret_cast(args.global_a_amax_list[g])); - } - // RNG for stochastic rounding - if constexpr (kEnableStochasticRounding) { - rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; - rng_offset = rng_state != nullptr ? 
__ldg(rng_state + 1) : 0; - } - // Input/output tensors/partitions for row quant warp - Tensor mQA = - make_tensor(cute::subbyte_iterator(QA), make_layout(make_shape(M, packed_N), dQA)); - Tensor gQA_mn = local_tile(mQA, epilogue_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); - Tensor mSFA = make_tensor(make_gmem_ptr(SFA), sfa_layout); - - Tensor gSFA_mn = local_tile(mSFA, epilogue_tiler, make_coord(_, _, _), - Step<_1, X, _1>{}); // (BLK_M,BLK_N) - // Swizzled shared memory A tile, with layout - Tensor sA = as_position_independent_swizzle_tensor(group_modes<0, 2>( - coalesce(make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), - sAlayout)))); // (BLOCK_M, BLOCK_M,PIPE) - - // Set up layouts for partitioning – tile-by-warp, with vector granularity - using S2RWarpLayout = Layout>; - using WarpGroupLayout = Layout>; - using S2RThreadLayout = decltype(blocked_product(S2RWarpLayout{}, WarpGroupLayout{})); - using S2RValLayout = Layout, _1>>; - using S2RAtomA = Copy_Atom; - using R2GAtomQA = Copy_Atom; - using R2GAtomSFA = Copy_Atom; - auto tiled_s2r = make_tiled_copy(S2RAtomA{}, S2RThreadLayout{}, S2RValLayout{}); - auto tiled_r2g_QA = make_tiled_copy(R2GAtomQA{}, S2RThreadLayout{}, S2RValLayout{}); - auto tiled_r2g_SFA = make_tiled_copy(R2GAtomSFA{}, S2RThreadLayout{}, S2RValLayout{}); - - auto thr_s2r = tiled_s2r.get_slice(local_thread_idx); - auto thr_r2g_QA = tiled_r2g_QA.get_slice(local_thread_idx); - auto thr_r2g_SFA = tiled_r2g_SFA.get_slice(local_thread_idx); - Tensor tQAsA = thr_s2r.partition_S(sA); // (Copy, Copy_M, Copy_N, PIPE) - - // Allocate temporary register tensors for copying quantization => output - Tensor tQArA = make_tensor_like( - make_layout(tQAsA(_, _, _, _0{}).shape())); // (Copy, Copy_M, Copy_N) - Tensor tQAgQA = thr_r2g_QA.partition_S(gQA_mn); - Tensor tQArQA = make_tensor_like(tQAgQA(_, _, _, _0{}, _0{})); - - Tensor tQAgSFA = thr_r2g_SFA.partition_S(gSFA_mn); - Tensor tQArSFA = make_tensor_like(tQAgSFA(_, _, _, _0{}, _0{})); - 
- // Will result in barrier_id=10 passed to bar.sync instr as cutlass adds 8 - // in order to go over the reserved named barrier count. - constexpr int row_quant_barrier_id = 2; - cutlass::arch::NamedBarrier::sync(NumEpilogueRowQuantThreadCount, row_quant_barrier_id); - - int group_idx = GetGroupIdx(&args, scheduler.tile_n_base() * size<1>(epilogue_tiler)); - float a_global_amax_val = shared_storage.global_a_amax[group_idx]; - // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} - static constexpr float fp4_max = 6.0f; - static constexpr float fp8_max = 448.0f; - static constexpr float fp4_max_inv = 1.0f / fp4_max; - float global_encode_scale = a_global_amax_val > 0.0f - ? cutlass::minimum_with_nan_propagation{}( - (fp8_max * fp4_max) / a_global_amax_val, - cutlass::platform::numeric_limits::max()) - : 1.0f; - - float global_decode_scale = 1.0f / global_encode_scale; - float global_encode_scale_multiplier = 1.0f; - if constexpr (kUseFastMath) { - global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; - } - auto sfa_converter = cutlass::NumericConverter{}; - do { - CUTLASS_PRAGMA_NO_UNROLL - for (int k_tile = 0; - k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) { - int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler); - - int cur_group_idx = GetGroupIdx(&args, global_tile_n_offset); - if (cur_group_idx != group_idx) { - group_idx = cur_group_idx; - a_global_amax_val = shared_storage.global_a_amax[group_idx]; - // Update group quantization parameters/scaling - global_encode_scale = a_global_amax_val > 0.0f - ? 
cutlass::minimum_with_nan_propagation{}( - (fp8_max * fp4_max) / a_global_amax_val, - cutlass::platform::numeric_limits::max()) - : 1.0f; - global_decode_scale = 1.0f / global_encode_scale; - if constexpr (kUseFastMath) { - global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + int group_start_offset = args.split_sections_range[group_idx]; + int local_tile_n_idx = + (global_tile_n_offset - group_start_offset) / size<1>(epilogue_tiler); + Tensor tDgD_mn = gD_mn_view(_, _, _, scheduler.tile_m(), local_tile_n_idx); + + Tensor tDgSFD_mn = gSFD_mn(_, _, scheduler.tile_m(), local_tile_n_idx); + accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); + + auto Acc = bulk_tmem_epilogue(_, _, _, accumulator_pipe_consumer_state.index()); + Tensor tDtAcc = thr_t2r.partition_S(Acc); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDgD = thr_t2r.partition_D(tDgD_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + + Tensor tTR_rAcc = make_tensor( + shape(tDgD)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDrD = make_tensor(shape(tDgD)); + Tensor tTR_rAcc_frag = + recast>(coalesce(tTR_rAcc)); + Tensor tDrD_frag = recast>(coalesce(tDrD)); + + Tensor src = thr_r2g.retile_S(tDrD); + Tensor dst = thr_r2g.retile_D(tDgD); + + Tensor tDgSFD_view = make_tensor( + tDgSFD_mn.data(), make_layout(make_shape(shape(tDgSFD_mn), Int<1>{}, Int<1>{}), + make_stride(stride(tDgSFD_mn), Int<0>{}, Int<0>{}))); + Tensor tDgSFD = filter(thr_t2r.partition_D(tDgSFD_view)); + Tensor tDrSFD = make_tensor(shape(tDgSFD)); + + static int constexpr NumVecs = size(tDgD) / VectorSize; + Tensor tD_rRowSFD_frg = recast>(tDrSFD); + + // Compute amax and quantization scales for this tile + cutlass::maximum_absolute_value_reduction< + cutlass::Array, true> + amax_reduction; + cutlass::Array vec_maxs; + cutlass::Array pvscales; + // Copy from TMEM to registers + copy(tiled_t2r, tDtAcc, tTR_rAcc); + cutlass::arch::fence_view_async_tmem_load(); + 
accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); + ++accumulator_pipe_consumer_state; + + if constexpr (!kUseFastMath) { + // Downcast to BF16 for bit-wise compatibility with + // unfused kernels + auto convert_accum_to_bf16 = + cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = + cutlass::NumericArrayConverter{}; + tTR_rAcc_frag(_0{}) = + convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + tTR_rAcc_frag(_1{}) = + convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_1{}))); } - } - auto tQAgSFA_mn = tQAgSFA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); - auto tQAgQA_mn = tQAgQA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); - auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state); - mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); - copy(tiled_s2r, tQAsA(_, _, _, mainloop_pipe_consumer_state.index()), tQArA); - cutlass::arch::fence_view_async_shared(); - mainloop_pipeline.consumer_release(mainloop_pipe_consumer_state); - ++mainloop_pipe_consumer_state; - ++k_tile; + auto compute_frgs = reinterpret_cast *>( + tTR_rAcc_frag.data()); + auto output_frgs = reinterpret_cast *>(tDrD_frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); + } - // static int constexpr NumVecs = size(tQArA) / VectorSize; - cutlass::maximum_absolute_value_reduction, - true> - amax_reduction; - auto compute_frgs = reinterpret_cast *>(tQArA.data()); - auto output_frgs = - reinterpret_cast *>(raw_pointer_cast(tQArQA.data())); - Tensor amax = - make_tensor(prepend(take<1, rank(tQArA)>(tQArA.shape()), _1{})); - Tensor pvscales = make_tensor_like(amax); - transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; - if constexpr (kEnableStochasticRounding) { - const size_t rng_sequence = global_thread_idx + k_tile * 512 + - 
scheduler.get_linear_tile_idx() * K_TILE_MAX * 512 + - tiles_in_m * tiles_in_n * K_TILE_MAX * 512; - rng.init(rng_seed, rng_sequence, rng_offset); - } - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < size<1>(group_modes<1, rank(tQArA)>(tQArA)); v++) { - auto amax_view = group_modes<1, rank(amax)>(amax); - auto pvscales_view = group_modes<1, rank(pvscales)>(pvscales); - auto compute_frgs_up = - cutlass::NumericArrayConverter{}( - compute_frgs[v]); - amax_view(_0{}, v) = amax_reduction(ElementAccumulator(0), compute_frgs_up); if constexpr (kUseFastMath) { // Fast math: multiply with precomputed reciprocal - pvscales_view(_0{}, v) = cutlass::multiplies{}( - amax_view(_0{}, v), global_encode_scale_multiplier); + pvscales = cutlass::multiplies>{}( + vec_maxs, global_encode_scale_multiplier); } else { // Accurate math: perform division - pvscales_view(_0{}, v) = - cutlass::divides{}(amax_view(_0{}, v), fp4_max); - pvscales_view(_0{}, v) = cutlass::multiplies{}( - pvscales_view(_0{}, v), global_encode_scale); + pvscales = cutlass::divides>{}(vec_maxs, + fp4_max); + pvscales = cutlass::multiplies>{}( + pvscales, global_encode_scale); } - filter(tQArSFA)(v) = sfa_converter(pvscales_view(_0{}, v)); - auto qpvscale_ups = - cutlass::NumericConverter{}(filter(tQArSFA)(v)); + auto pvscales_cvted = + cutlass::NumericArrayConverter{}(pvscales); + + tD_rRowSFD_frg(_0{}) = pvscales_cvted; + auto qpvscale_ups = cutlass::NumericArrayConverter{}( + tD_rRowSFD_frg(_0{})); auto qpvscale_scaled = - cutlass::multiplies{}(qpvscale_ups, global_decode_scale); - ElementAccumulator acc_scales; + cutlass::multiplies>{}( + qpvscale_ups, global_decode_scale); + cutlass::Array acc_scales; if constexpr (kUseFastMath) { // Fast math: compute approximate reciprocal acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); } else { // Accurate math: compute reciprocal with division - acc_scales = cutlass::divides{}(1.0, qpvscale_scaled); + acc_scales = cutlass::divides>{}( + 1.0, 
qpvscale_scaled); } - auto acc_scale = cutlass::minimum_with_nan_propagation{}( - acc_scales, cutlass::platform::numeric_limits::max()); + + // Prepare stochastic rounding random state if enabled uint4 random_uint4 = uint4{0, 0, 0, 0}; + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + // "Prefetch" a stochastic rounding state for the first tile if constexpr (kEnableStochasticRounding) { - random_uint4 = rng.generate4(); - output_frgs[v] = StochasticNumericConverter( - cutlass::multiplies>{}( - compute_frgs_up, acc_scale), - *reinterpret_cast *>(&random_uint4)); - } else { - output_frgs[v] = - cutlass::NumericArrayConverter{}( - cutlass::multiplies>{}( - compute_frgs_up, acc_scale)); + const size_t rng_sequence = global_thread_idx + k_tile * 512 + + scheduler.get_linear_tile_idx() * K_TILE_MAX * 512; + rng.init(rng_seed, rng_sequence, rng_offset); + } + CUTLASS_PRAGMA_UNROLL + // Apply round/quantize to each fragment, with or without stochastic rounding + for (int v = 0; v < NumVecs; v++) { + auto acc_scale = cutlass::minimum_with_nan_propagation{}( + acc_scales[v], cutlass::platform::numeric_limits::max()); + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter( + cutlass::multiplies>{}( + compute_frgs[v], acc_scale), + *reinterpret_cast *>(&random_uint4)); + } else { + output_frgs[v] = + cutlass::NumericArrayConverter{}( + cutlass::multiplies>{}( + compute_frgs[v], acc_scale)); + } } + + // Write quantized FP4 tile and dequant scale to gmem + copy(tiled_r2g, src, dst); + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFD, tDgSFD); } - copy(tiled_r2g_QA, tQArQA, tQAgQA_mn); - copy(tiled_r2g_SFA, filter(tQArSFA), filter(tQAgSFA_mn)); + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } + } else if (is_epilogue_row_quant_warp) { + // Warp responsible for quantizing the input (before Hadamard transform) to FP4 for row-wise usage. 
+ cutlass::arch::warpgroup_reg_alloc<136>(); + if constexpr (kEnableRowQuant) { + using S2RVectorType = uint128_t; + + int global_thread_idx = threadIdx.x; + int local_thread_idx = global_thread_idx % 256; + size_t rng_seed = 0; + size_t rng_offset = 0; + // g2s load all global_a_amax for all groups/tensors + CUTLASS_PRAGMA_NO_UNROLL + for (int g = local_thread_idx; g < args.num_tensors; g += NumEpilogueRowQuantThreadCount) { + shared_storage.global_a_amax[g] = + __ldg(reinterpret_cast(args.global_a_amax_list[g])); } - // scheduler.advance(); - scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); - ++sched_pipeline_consumer_state; - scheduler.update_work_tile_info(); - } while (scheduler.is_valid()); - } + // RNG for stochastic rounding + if constexpr (kEnableStochasticRounding) { + rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; + rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0; + } + // Input/output tensors/partitions for row quant warp + Tensor mQA = + make_tensor(cute::subbyte_iterator(QA), make_layout(make_shape(M, packed_N), dQA)); + Tensor gQA_mn = local_tile(mQA, epilogue_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor mSFA = make_tensor(make_gmem_ptr(SFA), sfa_layout); + + Tensor gSFA_mn = local_tile(mSFA, epilogue_tiler, make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_N) + // Swizzled shared memory A tile, with layout + Tensor sA = as_position_independent_swizzle_tensor(group_modes<0, 2>( + coalesce(make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), + sAlayout)))); // (BLOCK_M, BLOCK_M,PIPE) + + // Set up layouts for partitioning – tile-by-warp, with vector granularity + using S2RWarpLayout = Layout>; + using WarpGroupLayout = Layout>; + using S2RThreadLayout = decltype(blocked_product(S2RWarpLayout{}, WarpGroupLayout{})); + using S2RValLayout = Layout, _1>>; + using S2RAtomA = Copy_Atom; + using R2GAtomQA = Copy_Atom; + using R2GAtomSFA = Copy_Atom; + auto tiled_s2r = 
make_tiled_copy(S2RAtomA{}, S2RThreadLayout{}, S2RValLayout{}); + auto tiled_r2g_QA = make_tiled_copy(R2GAtomQA{}, S2RThreadLayout{}, S2RValLayout{}); + auto tiled_r2g_SFA = make_tiled_copy(R2GAtomSFA{}, S2RThreadLayout{}, S2RValLayout{}); + + auto thr_s2r = tiled_s2r.get_slice(local_thread_idx); + auto thr_r2g_QA = tiled_r2g_QA.get_slice(local_thread_idx); + auto thr_r2g_SFA = tiled_r2g_SFA.get_slice(local_thread_idx); + Tensor tQAsA = thr_s2r.partition_S(sA); // (Copy, Copy_M, Copy_N, PIPE) + + // Allocate temporary register tensors for copying quantization => output + Tensor tQArA = make_tensor_like( + make_layout(tQAsA(_, _, _, _0{}).shape())); // (Copy, Copy_M, Copy_N) + Tensor tQAgQA = thr_r2g_QA.partition_S(gQA_mn); + Tensor tQArQA = make_tensor_like(tQAgQA(_, _, _, _0{}, _0{})); + + Tensor tQAgSFA = thr_r2g_SFA.partition_S(gSFA_mn); + Tensor tQArSFA = make_tensor_like(tQAgSFA(_, _, _, _0{}, _0{})); + + // Will result in barrier_id=10 passed to bar.sync instr as cutlass adds 8 + // in order to go over the reserved named barrier count. + constexpr int row_quant_barrier_id = 2; + cutlass::arch::NamedBarrier::sync(NumEpilogueRowQuantThreadCount, row_quant_barrier_id); + + int group_idx = GetGroupIdx(&args, scheduler.tile_n_base() * size<1>(epilogue_tiler)); + float a_global_amax_val = shared_storage.global_a_amax[group_idx]; + // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} + static constexpr float fp4_max = 6.0f; + static constexpr float fp8_max = 448.0f; + static constexpr float fp4_max_inv = 1.0f / fp4_max; + float global_encode_scale = a_global_amax_val > 0.0f + ? 
cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / a_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + + float global_decode_scale = 1.0f / global_encode_scale; + float global_encode_scale_multiplier = 1.0f; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + auto sfa_converter = cutlass::NumericConverter{}; + do { + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) { + int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler); + + int cur_group_idx = GetGroupIdx(&args, global_tile_n_offset); + if (cur_group_idx != group_idx) { + group_idx = cur_group_idx; + a_global_amax_val = shared_storage.global_a_amax[group_idx]; + // Update group quantization parameters/scaling + global_encode_scale = a_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / a_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + global_decode_scale = 1.0f / global_encode_scale; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + } - } else { - cutlass::arch::warpgroup_reg_dealloc<32>(); - } - } // sm100 compile guard end + auto tQAgSFA_mn = + tQAgSFA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto tQAgQA_mn = tQAgQA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state); + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + copy(tiled_s2r, tQAsA(_, _, _, mainloop_pipe_consumer_state.index()), tQArA); + cutlass::arch::fence_view_async_shared(); + mainloop_pipeline.consumer_release(mainloop_pipe_consumer_state); + ++mainloop_pipe_consumer_state; + ++k_tile; + + // static int constexpr NumVecs = size(tQArA) / VectorSize; + 
cutlass::maximum_absolute_value_reduction< + cutlass::Array, true> + amax_reduction; + auto compute_frgs = reinterpret_cast *>(tQArA.data()); + auto output_frgs = reinterpret_cast *>( + raw_pointer_cast(tQArQA.data())); + Tensor amax = + make_tensor(prepend(take<1, rank(tQArA)>(tQArA.shape()), _1{})); + Tensor pvscales = make_tensor_like(amax); + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + if constexpr (kEnableStochasticRounding) { + const size_t rng_sequence = global_thread_idx + k_tile * 512 + + scheduler.get_linear_tile_idx() * K_TILE_MAX * 512 + + tiles_in_m * tiles_in_n * K_TILE_MAX * 512; + rng.init(rng_seed, rng_sequence, rng_offset); + } + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < size<1>(group_modes<1, rank(tQArA)>(tQArA)); v++) { + auto amax_view = group_modes<1, rank(amax)>(amax); + auto pvscales_view = group_modes<1, rank(pvscales)>(pvscales); + auto compute_frgs_up = + cutlass::NumericArrayConverter{}( + compute_frgs[v]); + amax_view(_0{}, v) = amax_reduction(ElementAccumulator(0), compute_frgs_up); + if constexpr (kUseFastMath) { + // Fast math: multiply with precomputed reciprocal + pvscales_view(_0{}, v) = cutlass::multiplies{}( + amax_view(_0{}, v), global_encode_scale_multiplier); + } else { + // Accurate math: perform division + pvscales_view(_0{}, v) = + cutlass::divides{}(amax_view(_0{}, v), fp4_max); + pvscales_view(_0{}, v) = cutlass::multiplies{}( + pvscales_view(_0{}, v), global_encode_scale); + } + filter(tQArSFA)(v) = sfa_converter(pvscales_view(_0{}, v)); + auto qpvscale_ups = + cutlass::NumericConverter{}(filter(tQArSFA)(v)); + auto qpvscale_scaled = + cutlass::multiplies{}(qpvscale_ups, global_decode_scale); + ElementAccumulator acc_scales; + if constexpr (kUseFastMath) { + // Fast math: compute approximate reciprocal + acc_scales = cutlass::reciprocal_approximate_ftz{}( + qpvscale_scaled); + } else { + // Accurate math: compute reciprocal with division + acc_scales = cutlass::divides{}(1.0, 
qpvscale_scaled); + } + auto acc_scale = cutlass::minimum_with_nan_propagation{}( + acc_scales, cutlass::platform::numeric_limits::max()); + uint4 random_uint4 = uint4{0, 0, 0, 0}; + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter( + cutlass::multiplies>{}( + compute_frgs_up, acc_scale), + *reinterpret_cast *>(&random_uint4)); + } else { + output_frgs[v] = + cutlass::NumericArrayConverter{}( + cutlass::multiplies>{}( + compute_frgs_up, acc_scale)); + } + } + copy(tiled_r2g_QA, tQArQA, tQAgQA_mn); + copy(tiled_r2g_SFA, filter(tQArSFA), filter(tQAgSFA_mn)); + } + // scheduler.advance(); + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } + + } else { + cutlass::arch::warpgroup_reg_dealloc<32>(); + } + } // sm100 compile guard end } // NOLINT(readability/fn_size) template Date: Tue, 3 Mar 2026 14:12:36 -0800 Subject: [PATCH 15/22] better naming for tests Signed-off-by: Zhongbo Zhu --- .../test_mxfp8_group_quantize_graph_safe.py | 52 +++++++++---------- .../test_mxfp8_quantize_swizzle_fusion.py | 22 ++++---- tests/pytorch/nvfp4/nvfp4_utils.py | 4 +- .../nvfp4/test_nvfp4_group_quantize.py | 24 ++++----- .../test_nvfp4_group_quantize_graph_safe.py | 48 ++++++++--------- .../nvfp4/test_nvfp4_quantize_exact.py | 8 +-- .../nvfp4/test_nvfp4_rht_quantize_exact.py | 40 +++++++------- 7 files changed, 99 insertions(+), 99 deletions(-) diff --git a/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py b/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py index 3c197bc6f3..939b6b58b1 100644 --- a/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py +++ b/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py @@ -79,7 +79,7 @@ def reference_group_quantize( x: torch.Tensor, quantizers: list[MXFP8Quantizer], split_sections: list[int], - 
return_identity: bool, + return_rowwise: bool, return_transpose: bool, ) -> torch.Tensor: x_chunks = torch.split(x, split_sections) @@ -94,7 +94,7 @@ def reference_group_quantize( for i in range(len(x_chunks)): x_chunk = x_chunks[i] x_mxfp8_res = quantizers[i](x_chunk) - if return_identity: + if return_rowwise: x_qx.append(x_mxfp8_res._rowwise_data.view(dtype=torch.uint8)) x_sx.append(x_mxfp8_res._rowwise_scale_inv) else: @@ -133,7 +133,7 @@ def check_grouped_tensor_mxfp8_versus_reference( x_dtype: torch.dtype, M: int, N: int, - return_identity: bool, + return_rowwise: bool, return_transpose: bool, split_sections: list[int], optimize_for_gemm: bool = False, @@ -157,7 +157,7 @@ def check_grouped_tensor_mxfp8_versus_reference( quantizers = [ MXFP8Quantizer( fp8_dtype=te_dtype, - rowwise=return_identity, + rowwise=return_rowwise, columnwise=return_transpose, ) for _ in range(len(split_sections)) @@ -169,14 +169,14 @@ def check_grouped_tensor_mxfp8_versus_reference( grouped_quantizer.optimize_for_gemm = optimize_for_gemm x_qx_ref, x_sx_ref, x_qx_t_ref, x_sx_t_ref = reference_group_quantize( - x, quantizers, split_sections, return_identity, return_transpose + x, quantizers, split_sections, return_rowwise, return_transpose ) group_quantized_output = fused_grouped_quantize(x, split_section_tensor, grouped_quantizer) # get a list of MXFP8 quantized tensors for testing split_quantize_outputs = group_quantized_output.split_into_quantized_tensors() - if return_identity: + if return_rowwise: x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] x_sx = [output._rowwise_scale_inv for output in split_quantize_outputs] @@ -229,7 +229,7 @@ def check_grouped_tensor_mxfp8_with_paged_stashing( x_dtype: torch.dtype, M: int, N: int, - return_identity: bool, + return_rowwise: bool, return_transpose: bool, split_sections: list[int], valid_M: int = None, @@ -258,7 +258,7 @@ def check_grouped_tensor_mxfp8_with_paged_stashing( quantizers = [ 
MXFP8Quantizer( fp8_dtype=te_dtype, - rowwise=return_identity, + rowwise=return_rowwise, columnwise=return_transpose, ) for _ in range(len(split_sections)) @@ -270,7 +270,7 @@ def check_grouped_tensor_mxfp8_with_paged_stashing( grouped_quantizer.optimize_for_gemm = optimize_for_gemm x_qx_ref, x_sx_ref, x_qx_t_ref, x_sx_t_ref = reference_group_quantize( - valid_x, quantizers, split_sections, return_identity, return_transpose + valid_x, quantizers, split_sections, return_rowwise, return_transpose ) # Note: for grouped quantize with paged stashing @@ -281,7 +281,7 @@ def check_grouped_tensor_mxfp8_with_paged_stashing( # get a list of MXFP8 quantized tensors for testing split_quantize_outputs = group_quantized_output.split_into_quantized_tensors() - if return_identity: + if return_rowwise: x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] x_sx = [output._rowwise_scale_inv for output in split_quantize_outputs] @@ -356,7 +356,7 @@ def check_grouped_tensor_mxfp8_with_paged_stashing( ], ) @pytest.mark.parametrize( - "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] + "quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"] ) @pytest.mark.parametrize( "optimize_for_gemm", [True, False], ids=["optimize_for_gemm", "no_optimize_for_gemm"] @@ -372,14 +372,14 @@ def test_grouped_tensor_mxfp8_versus_reference( split_sections = generate_split_sections(M, N, edge_cases) - if quantize_mode == "quantize": - return_identity = True + if quantize_mode == "rowwise_only": + return_rowwise = True return_transpose = False - elif quantize_mode == "quantize_transpose": - return_identity = True + elif quantize_mode == "both_directions": + return_rowwise = True return_transpose = True - elif quantize_mode == "quantize_colwise_only": - return_identity = False + elif quantize_mode == "columnwise_only": + return_rowwise = False return_transpose = True else: raise ValueError(f"Invalid quantize mode: {quantize_mode}") 
@@ -388,7 +388,7 @@ def test_grouped_tensor_mxfp8_versus_reference( x_dtype=x_dtype, M=M, N=N, - return_identity=return_identity, + return_rowwise=return_rowwise, return_transpose=return_transpose, split_sections=split_sections, optimize_for_gemm=optimize_for_gemm, @@ -423,7 +423,7 @@ def test_grouped_tensor_mxfp8_versus_reference( ], ) @pytest.mark.parametrize( - "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] + "quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"] ) @pytest.mark.parametrize( "optimize_for_gemm", [True, False], ids=["optimize_for_gemm", "no_optimize_for_gemm"] @@ -451,14 +451,14 @@ def test_grouped_tensor_mxfp8_with_paged_stashing( else: assert valid_M == M // 2, "valid_M must be M // 2 when edge_cases is not zero_tokens_all" - if quantize_mode == "quantize": - return_identity = True + if quantize_mode == "rowwise_only": + return_rowwise = True return_transpose = False - elif quantize_mode == "quantize_transpose": - return_identity = True + elif quantize_mode == "both_directions": + return_rowwise = True return_transpose = True - elif quantize_mode == "quantize_colwise_only": - return_identity = False + elif quantize_mode == "columnwise_only": + return_rowwise = False return_transpose = True else: raise ValueError(f"Invalid quantize mode: {quantize_mode}") @@ -467,7 +467,7 @@ def test_grouped_tensor_mxfp8_with_paged_stashing( x_dtype=x_dtype, M=M, N=N, - return_identity=return_identity, + return_rowwise=return_rowwise, return_transpose=return_transpose, split_sections=split_sections, valid_M=valid_M, diff --git a/tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py b/tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py index 94ea699d14..585786a47d 100644 --- a/tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py +++ b/tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py @@ -39,7 +39,7 @@ def check_mxfp8_quantize_swizzle_fusion( x_dtype: torch.dtype, M: int, N: int, - 
return_identity: bool, + return_rowwise: bool, return_transpose: bool, ) -> None: @@ -57,7 +57,7 @@ def check_mxfp8_quantize_swizzle_fusion( # Quantize quantizer = MXFP8Quantizer( fp8_dtype=te_dtype, - rowwise=return_identity, + rowwise=return_rowwise, columnwise=return_transpose, ) @@ -69,7 +69,7 @@ def check_mxfp8_quantize_swizzle_fusion( ) x_qx_ref, x_sx_ref, x_qx_t_ref, x_sx_t_ref = unpack_quantized_tensor(quantizer(x)) - if return_identity: + if return_rowwise: torch.testing.assert_close(x_qx_swf, x_qx_ref, atol=0.0, rtol=0.0) valid_scale_shape = get_mxfp8_scale_shape_no_padding(x.shape, False) assert valid_scale_shape == x_sx_swf.shape, ( @@ -104,7 +104,7 @@ def check_mxfp8_quantize_swizzle_fusion( ) @pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize( - "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] + "quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"] ) def test_mxfp8_quantize_swizzle_fusion( x_dtype: torch.dtype, @@ -113,14 +113,14 @@ def test_mxfp8_quantize_swizzle_fusion( quantize_mode: str, ) -> None: - if quantize_mode == "quantize": - return_identity = True + if quantize_mode == "rowwise_only": + return_rowwise = True return_transpose = False - elif quantize_mode == "quantize_transpose": - return_identity = True + elif quantize_mode == "both_directions": + return_rowwise = True return_transpose = True - elif quantize_mode == "quantize_colwise_only": - return_identity = False + elif quantize_mode == "columnwise_only": + return_rowwise = False return_transpose = True else: raise ValueError(f"Invalid quantize mode: {quantize_mode}") @@ -129,6 +129,6 @@ def test_mxfp8_quantize_swizzle_fusion( x_dtype=x_dtype, M=M, N=N, - return_identity=return_identity, + return_rowwise=return_rowwise, return_transpose=return_transpose, ) diff --git a/tests/pytorch/nvfp4/nvfp4_utils.py b/tests/pytorch/nvfp4/nvfp4_utils.py index 5f1b5ac36c..757ed249d2 100644 --- 
a/tests/pytorch/nvfp4/nvfp4_utils.py +++ b/tests/pytorch/nvfp4/nvfp4_utils.py @@ -115,7 +115,7 @@ def reference_group_quantize( x: torch.Tensor, quantizers: list[NVFP4Quantizer], split_sections: list[int], - return_identity: bool, + return_rowwise: bool, return_transpose: bool, ) -> torch.Tensor: x_view = x.reshape(-1, x.size(-1)) @@ -133,7 +133,7 @@ def reference_group_quantize( for i in range(len(x_chunks)): x_chunk = x_chunks[i] x_nvfp4_res = quantizers[i](x_chunk) - if return_identity: + if return_rowwise: x_qx.append(x_nvfp4_res._rowwise_data.view(dtype=torch.uint8)) x_sx.append(x_nvfp4_res._rowwise_scale_inv) x_amax_rowwise.append(x_nvfp4_res._amax_rowwise) diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py index 5f35e9ad10..4074a83ee5 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py @@ -37,7 +37,7 @@ def check_group_quantization_nvfp4_versus_reference( x_dtype: torch.dtype, M: int, N: int, - return_identity: bool, + return_rowwise: bool, return_transpose: bool, split_sections: list[int], with_rht: bool = True, @@ -63,7 +63,7 @@ def check_group_quantization_nvfp4_versus_reference( quantizers = [ NVFP4Quantizer( fp4_dtype=te_dtype, - rowwise=return_identity, + rowwise=return_rowwise, columnwise=return_transpose, with_amax_reduction=False, amax_reduction_group=None, @@ -74,12 +74,12 @@ def check_group_quantization_nvfp4_versus_reference( for _ in range(len(split_sections)) ] x_qx_ref, x_sx_ref, x_amax_rowwise_ref, x_qx_t_ref, x_sx_t_ref, x_amax_colwise_ref = ( - reference_group_quantize(x, quantizers, split_sections, return_identity, return_transpose) + reference_group_quantize(x, quantizers, split_sections, return_rowwise, return_transpose) ) split_quantize_outputs = tex.split_quantize(x, split_sections, quantizers) - if return_identity: + if return_rowwise: x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in 
split_quantize_outputs] x_sx = [output._rowwise_scale_inv for output in split_quantize_outputs] x_amax_rowwise = [output._amax_rowwise for output in split_quantize_outputs] @@ -151,7 +151,7 @@ def check_group_quantization_nvfp4_versus_reference( ], ) @pytest.mark.parametrize( - "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] + "quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"] ) @pytest.mark.parametrize( "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] @@ -172,14 +172,14 @@ def test_rht_with_quantization_block_tiling_versus_reference( # currently disable pre-RHT amax with_post_rht_amax = with_rht - if quantize_mode == "quantize": - return_identity = True + if quantize_mode == "rowwise_only": + return_rowwise = True return_transpose = False - elif quantize_mode == "quantize_transpose": - return_identity = True + elif quantize_mode == "both_directions": + return_rowwise = True return_transpose = True - elif quantize_mode == "quantize_colwise_only": - return_identity = False + elif quantize_mode == "columnwise_only": + return_rowwise = False return_transpose = True else: raise ValueError(f"Invalid quantize mode: {quantize_mode}") @@ -188,7 +188,7 @@ def test_rht_with_quantization_block_tiling_versus_reference( x_dtype=x_dtype, M=M, N=N, - return_identity=return_identity, + return_rowwise=return_rowwise, return_transpose=return_transpose, split_sections=split_sections, with_rht=with_rht, diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py index 1e62f91eb8..2abca19229 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py @@ -46,7 +46,7 @@ def check_grouped_tensor_nvfp4_versus_reference( x_dtype: torch.dtype, M: int, N: int, - return_identity: bool, + return_rowwise: bool, return_transpose: bool, split_sections: 
list[int], with_rht: bool = True, @@ -75,7 +75,7 @@ def check_grouped_tensor_nvfp4_versus_reference( quantizers = [ NVFP4Quantizer( fp4_dtype=te_dtype, - rowwise=return_identity, + rowwise=return_rowwise, columnwise=return_transpose, with_amax_reduction=False, amax_reduction_group=None, @@ -92,14 +92,14 @@ def check_grouped_tensor_nvfp4_versus_reference( grouped_quantizer.optimize_for_gemm = optimize_for_gemm x_qx_ref, x_sx_ref, x_amax_rowwise_ref, x_qx_t_ref, x_sx_t_ref, x_amax_colwise_ref = ( - reference_group_quantize(x, quantizers, split_sections, return_identity, return_transpose) + reference_group_quantize(x, quantizers, split_sections, return_rowwise, return_transpose) ) group_quantized_output = fused_grouped_quantize(x, split_section_tensor, grouped_quantizer) # get a list of nvfp4 quantized tensors for testing split_quantize_outputs = group_quantized_output.split_into_quantized_tensors() - if return_identity: + if return_rowwise: x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] x_sx = [output._rowwise_scale_inv for output in split_quantize_outputs] x_amax_rowwise = [output._amax_rowwise for output in split_quantize_outputs] @@ -162,7 +162,7 @@ def check_grouped_tensor_nvfp4_with_paged_stashing( x_dtype: torch.dtype, M: int, N: int, - return_identity: bool, + return_rowwise: bool, return_transpose: bool, split_sections: list[int], with_rht: bool = True, @@ -196,7 +196,7 @@ def check_grouped_tensor_nvfp4_with_paged_stashing( quantizers = [ NVFP4Quantizer( fp4_dtype=te_dtype, - rowwise=return_identity, + rowwise=return_rowwise, columnwise=return_transpose, with_amax_reduction=False, amax_reduction_group=None, @@ -214,7 +214,7 @@ def check_grouped_tensor_nvfp4_with_paged_stashing( x_qx_ref, x_sx_ref, x_amax_rowwise_ref, x_qx_t_ref, x_sx_t_ref, x_amax_colwise_ref = ( reference_group_quantize( - valid_x, quantizers, split_sections, return_identity, return_transpose + valid_x, quantizers, split_sections, return_rowwise, 
return_transpose ) ) @@ -226,7 +226,7 @@ def check_grouped_tensor_nvfp4_with_paged_stashing( # get a list of nvfp4 quantized tensors for testing split_quantize_outputs = group_quantized_output.split_into_quantized_tensors() - if return_identity: + if return_rowwise: x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] x_sx = [output._rowwise_scale_inv for output in split_quantize_outputs] x_amax_rowwise = [output._amax_rowwise for output in split_quantize_outputs] @@ -308,7 +308,7 @@ def check_grouped_tensor_nvfp4_with_paged_stashing( ], ) @pytest.mark.parametrize( - "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] + "quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"] ) @pytest.mark.parametrize( "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] @@ -333,14 +333,14 @@ def test_grouped_tensor_nvfp4_versus_reference( # currently disable pre-RHT amax with_post_rht_amax = with_rht - if quantize_mode == "quantize": - return_identity = True + if quantize_mode == "rowwise_only": + return_rowwise = True return_transpose = False - elif quantize_mode == "quantize_transpose": - return_identity = True + elif quantize_mode == "both_directions": + return_rowwise = True return_transpose = True - elif quantize_mode == "quantize_colwise_only": - return_identity = False + elif quantize_mode == "columnwise_only": + return_rowwise = False return_transpose = True else: raise ValueError(f"Invalid quantize mode: {quantize_mode}") @@ -349,7 +349,7 @@ def test_grouped_tensor_nvfp4_versus_reference( x_dtype=x_dtype, M=M, N=N, - return_identity=return_identity, + return_rowwise=return_rowwise, return_transpose=return_transpose, split_sections=split_sections, with_rht=with_rht, @@ -387,7 +387,7 @@ def test_grouped_tensor_nvfp4_versus_reference( ], ) @pytest.mark.parametrize( - "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] + 
"quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"] ) @pytest.mark.parametrize( "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] @@ -424,14 +424,14 @@ def test_grouped_tensor_nvfp4_with_paged_stashing( # currently disable pre-RHT amax with_post_rht_amax = with_rht - if quantize_mode == "quantize": - return_identity = True + if quantize_mode == "rowwise_only": + return_rowwise = True return_transpose = False - elif quantize_mode == "quantize_transpose": - return_identity = True + elif quantize_mode == "both_directions": + return_rowwise = True return_transpose = True - elif quantize_mode == "quantize_colwise_only": - return_identity = False + elif quantize_mode == "columnwise_only": + return_rowwise = False return_transpose = True else: raise ValueError(f"Invalid quantize mode: {quantize_mode}") @@ -440,7 +440,7 @@ def test_grouped_tensor_nvfp4_with_paged_stashing( x_dtype=x_dtype, M=M, N=N, - return_identity=return_identity, + return_rowwise=return_rowwise, return_transpose=return_transpose, split_sections=split_sections, with_rht=with_rht, diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 80ccb2f23d..b8a0679624 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -148,7 +148,7 @@ def check_quantization_nvfp4_versus_reference( ) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize( - "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] + "return_transpose", [True, False], ids=["both_directions", "rowwise_only"] ) @pytest.mark.parametrize("swizzled_scale", [False], ids=["linear_scale"]) @pytest.mark.parametrize( @@ -187,7 +187,7 @@ def test_quantization_block_tiling_versus_reference( @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) 
@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"]) @pytest.mark.parametrize( - "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] + "return_transpose", [True, False], ids=["both_directions", "rowwise_only"] ) @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] @@ -287,7 +287,7 @@ def test_nvfp4_quantization_extrema_versus_reference( ) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize( - "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] + "return_transpose", [True, False], ids=["both_directions", "rowwise_only"] ) @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] @@ -400,7 +400,7 @@ def test_nvfp4_quantization_boundary_values( ) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize( - "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] + "return_transpose", [True, False], ids=["both_directions", "rowwise_only"] ) @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index 5826c4b95f..ae0889feac 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -35,7 +35,7 @@ def check_quantization_nvfp4_versus_reference( M: int, N: int, contiguous: bool, - return_identity: bool, + return_rowwise: bool, return_transpose: bool, use_cpp_allocator: bool, swizzled_scale: bool = False, @@ -62,7 +62,7 @@ def check_quantization_nvfp4_versus_reference( # Quantize nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, - rowwise=return_identity, + rowwise=return_rowwise, columnwise=return_transpose, with_amax_reduction=False, 
amax_reduction_group=None, @@ -100,7 +100,7 @@ def check_quantization_nvfp4_versus_reference( # Reference quantization using NVFP4QuantizerRef with built-in RHT ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, - rowwise=return_identity, + rowwise=return_rowwise, columnwise=return_transpose, pow_2_scales=False, eps=0.0, @@ -133,7 +133,7 @@ def check_quantization_nvfp4_versus_reference( sx_t_ref = None ref_amax_colwise_t = None - if return_identity: + if return_rowwise: torch.testing.assert_close(amax_rowwise, ref_amax_rowwise, atol=0.0, rtol=0.0) torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) @@ -189,7 +189,7 @@ def check_quantization_nvfp4_versus_reference( ) @pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize( - "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] + "quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"] ) @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] @@ -206,14 +206,14 @@ def test_rht_with_quantization_block_tiling_versus_reference( with_random_sign_mask: bool, ) -> None: - if quantize_mode == "quantize": - return_identity = True + if quantize_mode == "rowwise_only": + return_rowwise = True return_transpose = False - elif quantize_mode == "quantize_transpose": - return_identity = True + elif quantize_mode == "both_directions": + return_rowwise = True return_transpose = True - elif quantize_mode == "quantize_colwise_only": - return_identity = False + elif quantize_mode == "columnwise_only": + return_rowwise = False return_transpose = True else: raise ValueError(f"Invalid quantize mode: {quantize_mode}") @@ -223,7 +223,7 @@ def test_rht_with_quantization_block_tiling_versus_reference( M=M, N=N, contiguous=True, - return_identity=return_identity, + return_rowwise=return_rowwise, return_transpose=return_transpose, use_cpp_allocator=use_cpp_allocator, 
with_random_sign_mask=with_random_sign_mask, @@ -239,7 +239,7 @@ def test_rht_with_quantization_block_tiling_versus_reference( ) @pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize( - "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] + "quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"] ) @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] @@ -256,14 +256,14 @@ def test_nvfp4_quantization_noncontiguous_inputs( with_random_sign_mask: bool, ): - if quantize_mode == "quantize": - return_identity = True + if quantize_mode == "rowwise_only": + return_rowwise = True return_transpose = False - elif quantize_mode == "quantize_transpose": - return_identity = True + elif quantize_mode == "both_directions": + return_rowwise = True return_transpose = True - elif quantize_mode == "quantize_colwise_only": - return_identity = False + elif quantize_mode == "columnwise_only": + return_rowwise = False return_transpose = True else: raise ValueError(f"Invalid quantize mode: {quantize_mode}") @@ -273,7 +273,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( M=M, N=N, contiguous=False, - return_identity=return_identity, + return_rowwise=return_rowwise, return_transpose=return_transpose, use_cpp_allocator=use_cpp_allocator, with_random_sign_mask=with_random_sign_mask, From 43c224016093e36d3026deaa80e17c79cf3dfd3d Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Tue, 3 Mar 2026 14:18:08 -0800 Subject: [PATCH 16/22] fix deprecate messsage Signed-off-by: Zhongbo Zhu --- .../common/include/transformer_engine/hadamard_transform.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h index 75729967a3..c18b3135a7 100644 --- 
a/transformer_engine/common/include/transformer_engine/hadamard_transform.h +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -48,8 +48,7 @@ void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output, int /*! \brief Perform the columnwise hadamard transform cast fusion. * - * This function is experimental and the API is not stable. - * This function will later be deprecated and replaced by nvte_hadamard_transform_cast_fusion + * This has been deprecated in favor of nvte_hadamard_transform_cast_fusion. * * \param[in] input Input tensor to apply Hadamard transform. * \param[in,out] output Output tensor. From e60916903bc05e644603b8b0f928e8a7f97e2e81 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Tue, 3 Mar 2026 14:42:47 -0800 Subject: [PATCH 17/22] more compile guard Signed-off-by: Zhongbo Zhu --- ..._group_row_cast_col_hadamard_transform_cast_fusion.cu | 5 +++-- .../group_hadamard_transform_cast_fusion.cu | 9 +++++++++ .../hadamard_transform/hadamard_transform_cast_fusion.cu | 9 +++++++++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu index 19583b3afb..9222994ca9 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -198,8 +198,8 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g "with architecture-specific compilation. 
" "Try recompiling with sm_100a or similar."); return; - } - static_assert(kEnableRHTColQuant_ || kEnableRowQuant_, + } else { + static_assert(kEnableRHTColQuant_ || kEnableRowQuant_, "group_row_col_rht_gemm_device_graph_safe must generate row-wise " "and/or column-wise output."); #if !defined(CUTLASS_ARCH_CLC_ENABLED) @@ -1141,6 +1141,7 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g } else { cutlass::arch::warpgroup_reg_dealloc<32>(); } + } } // NOLINT(readability/fn_size) template Date: Tue, 3 Mar 2026 14:43:49 -0800 Subject: [PATCH 18/22] new API name Signed-off-by: Zhongbo Zhu --- .../row_cast_col_hadamard_transform_cast_fusion.cu | 4 ++-- .../common/include/transformer_engine/hadamard_transform.h | 4 ++-- transformer_engine/pytorch/csrc/quantizer.cpp | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu index 1cd36c4681..ed0b4b089f 100644 --- a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu @@ -1378,11 +1378,11 @@ void hadamard_transform_cast_fusion(const Tensor &input_, Tensor &output_, } // namespace transformer_engine -void nvte_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor output, +void nvte_quantize_with_hadamard_transform(const NVTETensor input, NVTETensor output, const NVTETensor hadamard_matrix, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - NVTE_API_CALL(nvte_hadamard_transform_cast_fusion); + NVTE_API_CALL(nvte_quantize_with_hadamard_transform); using namespace transformer_engine; QuantizationConfig quant_config_cpp; if (quant_config != nullptr) { diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h 
b/transformer_engine/common/include/transformer_engine/hadamard_transform.h index c18b3135a7..cf7a1640b9 100644 --- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -48,7 +48,7 @@ void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output, int /*! \brief Perform the columnwise hadamard transform cast fusion. * - * This has been deprecated in favor of nvte_hadamard_transform_cast_fusion. + * This has been deprecated in favor of nvte_quantize_with_hadamard_transform. * * \param[in] input Input tensor to apply Hadamard transform. * \param[in,out] output Output tensor. @@ -71,7 +71,7 @@ void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTE * \param[in] quant_config Quantization configuration. * \param[in] stream CUDA stream used for the operation. */ -void nvte_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor output, +void nvte_quantize_with_hadamard_transform(const NVTETensor input, NVTETensor output, const NVTETensor hadamard_matrix, const NVTEQuantizationConfig quant_config, cudaStream_t stream); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index f79f4d0f2c..1a8069df09 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -2264,7 +2264,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // 1. Rowwise quantization // 2. 
RHT followed by columnwise quantization & transpose NVTE_SCOPED_GIL_RELEASE({ - nvte_hadamard_transform_cast_fusion(input.data(), out.data(), rht_matrix_nvte.data(), + nvte_quantize_with_hadamard_transform(input.data(), out.data(), rht_matrix_nvte.data(), quant_config, stream); }); } else { From 3a93a72eda181760a356f538aad55a9832b10d2b Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Tue, 3 Mar 2026 15:46:29 -0800 Subject: [PATCH 19/22] fix format all in one Signed-off-by: Zhongbo Zhu --- .../test_mxfp8_group_quantize_graph_safe.py | 8 +- .../test_mxfp8_quantize_swizzle_fusion.py | 4 +- .../nvfp4/test_nvfp4_group_quantize.py | 4 +- .../test_nvfp4_group_quantize_graph_safe.py | 8 +- .../nvfp4/test_nvfp4_quantize_exact.py | 16 +- .../nvfp4/test_nvfp4_rht_quantize_exact.py | 8 +- ...cast_col_hadamard_transform_cast_fusion.cu | 1758 +++++++++-------- .../group_hadamard_transform_cast_fusion.cu | 1010 +++++----- ...cast_col_hadamard_transform_cast_fusion.cu | 6 +- .../transformer_engine/hadamard_transform.h | 6 +- transformer_engine/pytorch/csrc/quantizer.cpp | 2 +- 11 files changed, 1408 insertions(+), 1422 deletions(-) diff --git a/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py b/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py index 939b6b58b1..c2f8e8de12 100644 --- a/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py +++ b/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py @@ -355,9 +355,7 @@ def check_grouped_tensor_mxfp8_with_paged_stashing( "random_uneven_split", ], ) -@pytest.mark.parametrize( - "quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"] -) +@pytest.mark.parametrize("quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"]) @pytest.mark.parametrize( "optimize_for_gemm", [True, False], ids=["optimize_for_gemm", "no_optimize_for_gemm"] ) @@ -422,9 +420,7 @@ def test_grouped_tensor_mxfp8_versus_reference( "random_uneven_split", ], ) -@pytest.mark.parametrize( - "quantize_mode", 
["rowwise_only", "both_directions", "columnwise_only"] -) +@pytest.mark.parametrize("quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"]) @pytest.mark.parametrize( "optimize_for_gemm", [True, False], ids=["optimize_for_gemm", "no_optimize_for_gemm"] ) diff --git a/tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py b/tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py index 585786a47d..6f0700809b 100644 --- a/tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py +++ b/tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py @@ -103,9 +103,7 @@ def check_mxfp8_quantize_swizzle_fusion( ], ) @pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) -@pytest.mark.parametrize( - "quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"] -) +@pytest.mark.parametrize("quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"]) def test_mxfp8_quantize_swizzle_fusion( x_dtype: torch.dtype, M: int, diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py index 4074a83ee5..7fc777f010 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py @@ -150,9 +150,7 @@ def check_group_quantization_nvfp4_versus_reference( "random_uneven_split", ], ) -@pytest.mark.parametrize( - "quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"] -) +@pytest.mark.parametrize("quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"]) @pytest.mark.parametrize( "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py index 2abca19229..64268c512e 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py @@ -307,9 +307,7 @@ def 
check_grouped_tensor_nvfp4_with_paged_stashing( "random_uneven_split", ], ) -@pytest.mark.parametrize( - "quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"] -) +@pytest.mark.parametrize("quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"]) @pytest.mark.parametrize( "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] ) @@ -386,9 +384,7 @@ def test_grouped_tensor_nvfp4_versus_reference( "random_uneven_split", ], ) -@pytest.mark.parametrize( - "quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"] -) +@pytest.mark.parametrize("quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"]) @pytest.mark.parametrize( "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index b8a0679624..bf3f545b8b 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -147,9 +147,7 @@ def check_quantization_nvfp4_versus_reference( ], ) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) -@pytest.mark.parametrize( - "return_transpose", [True, False], ids=["both_directions", "rowwise_only"] -) +@pytest.mark.parametrize("return_transpose", [True, False], ids=["both_directions", "rowwise_only"]) @pytest.mark.parametrize("swizzled_scale", [False], ids=["linear_scale"]) @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] @@ -186,9 +184,7 @@ def test_quantization_block_tiling_versus_reference( ) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"]) -@pytest.mark.parametrize( - "return_transpose", [True, False], ids=["both_directions", "rowwise_only"] -) 
+@pytest.mark.parametrize("return_transpose", [True, False], ids=["both_directions", "rowwise_only"]) @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @@ -286,9 +282,7 @@ def test_nvfp4_quantization_extrema_versus_reference( ], ) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) -@pytest.mark.parametrize( - "return_transpose", [True, False], ids=["both_directions", "rowwise_only"] -) +@pytest.mark.parametrize("return_transpose", [True, False], ids=["both_directions", "rowwise_only"]) @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @@ -399,9 +393,7 @@ def test_nvfp4_quantization_boundary_values( ], ) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) -@pytest.mark.parametrize( - "return_transpose", [True, False], ids=["both_directions", "rowwise_only"] -) +@pytest.mark.parametrize("return_transpose", [True, False], ids=["both_directions", "rowwise_only"]) @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index ae0889feac..795721df04 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -188,9 +188,7 @@ def check_quantization_nvfp4_versus_reference( ], ) @pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) -@pytest.mark.parametrize( - "quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"] -) +@pytest.mark.parametrize("quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"]) @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @@ -238,9 +236,7 @@ def test_rht_with_quantization_block_tiling_versus_reference( ], ) @pytest.mark.parametrize("x_dtype", 
[torch.bfloat16], ids=str) -@pytest.mark.parametrize( - "quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"] -) +@pytest.mark.parametrize("quantize_mode", ["rowwise_only", "both_directions", "columnwise_only"]) @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu index 9222994ca9..b4d0632422 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -200,947 +200,953 @@ __launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_g return; } else { static_assert(kEnableRHTColQuant_ || kEnableRowQuant_, - "group_row_col_rht_gemm_device_graph_safe must generate row-wise " - "and/or column-wise output."); + "group_row_col_rht_gemm_device_graph_safe must generate row-wise " + "and/or column-wise output."); #if !defined(CUTLASS_ARCH_CLC_ENABLED) - CUTLASS_NOT_IMPLEMENTED(); - return; + CUTLASS_NOT_IMPLEMENTED(); + return; #endif - using X = Underscore; - // Accumulator data type for main computation - using ElementAccumulator = float; - static int constexpr K_PIPE_MAX = size<3>(ASmemLayout{}); - using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>; - static uint32_t constexpr kTmaTransactionBytes = cutlass::bits_to_bytes( - size(AtomThrShapeMNK{}) * cosize(take<0, 3>(ASmemLayout{})) * cute::sizeof_bits_v); - static constexpr bool kEnableStochasticRounding = kEnableStochasticRounding_; - static constexpr bool kEnableRHTColQuant = kEnableRHTColQuant_; - static constexpr bool kEnableRowQuant = kEnableRowQuant_; - static constexpr bool 
kEnableSwizzleSFOutput = kEnableSwizzleSFOutput_; - static constexpr bool kUseFastMath = kUseFastMath_; - - // Constant for RHT tensor processing (tile size etc) - static int constexpr RhtTensorSize = 16; - - // Get the total number of tokens to process - // Note that here M is the hidden size, which is the last logical dimension of the input tensor x - // The kernel is designed in column major, so M is the hidden size - size_t sum_token_dims = offsets[num_tensors] / M; - - // Transaction bytes for TMA transfer on RHT tensor blocks - static int constexpr kTmaRhtTensorTransactionBytes = - cutlass::bits_to_bytes(RhtTensorSize * RhtTensorSize * cute::sizeof_bits_v); - static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; - static int constexpr SchedulerPipelineStageCount = SchedulerPipelineStageCount_; - - // Mainloop pipeline stage calculation, vectorization parameters for scaling factors - static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{}); - static int constexpr SFVecSize = 16; - // Swizzle output layout for scaling factor arrays - using SwizzledSFALayoutAtom = - cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; - using SwizzledSFDLayoutAtom = - cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; - - // Mainloop pipeline types for TMA async execution and epilogue cluster scheduling - using MainloopPipeline = - cutlass::detail::CustomizedPipelineTmaUmmaAsync; - using MainloopPipelineState = typename MainloopPipeline::PipelineState; - using SchedPipeline = cutlass::PipelineCLCFetchAsync; - using SchedPipelineState = typename SchedPipeline::PipelineState; - using SchedThrottlePipeline = cutlass::PipelineAsync; - using SchedThrottlePipelineState = typename SchedThrottlePipeline::PipelineState; - - static_assert(ClusterShape{} == Shape<_1, _1, _1>{}, "ClusterShape must be Shape<_1,_1,_1>"); - - using TmemAllocator = cute::TMEM::Allocator1Sm; - static int constexpr VectorSize = RhtTensorSize; - - // 
Compile-time safety: static shapes required for shared memory layouts - CUTE_STATIC_ASSERT(is_static::value); - CUTE_STATIC_ASSERT(is_static::value); - // CUTE_STATIC_ASSERT(is_static::value); - - auto cluster_size = size<0>(cluster_shape); - auto mainloop_tiler = Shape<_128, _16, _128>{}; - auto epilogue_tiler = Shape<_128, _128, _128>{}; - - static int constexpr EpilogueUnrollFactor = size<2>(epilogue_tiler) / size<2>(cluster_tile); - - // Get the appropriate blocks for this Cluster - dim3 cluster_coord_in_grid = cluster_id_in_grid(); - - // Total number of k-tiles - int const K_TILE_MAX = min(packed_N, K) / size<2>(epilogue_tiler); - - struct TileScheduler { - uint32_t tiles_in_m = 0; - uint32_t tiles_in_n = 0; - uint32_t linear_idx = 0; - uint32_t next_linear_idx = 0; - uint32_t start_idx = 0; - uint32_t tile_m_idx = 0; - uint32_t tile_n_idx = 0; - int k_tile_max = 0; - uint32_t *atomic_tile_index_; - uint32_t *smem_tile_counter; - uint32_t atomic_offset; - cutlass::FastDivmodU64 divmod_tiles_in_m; - - CUTLASS_DEVICE TileScheduler(uint32_t tiles_m, uint32_t tiles_n, int kmax, - uint32_t *atomic_tile_index, uint32_t *smem_tile_counter) - : tiles_in_m(tiles_m), - tiles_in_n(tiles_n), - linear_idx(blockIdx.x), - next_linear_idx(blockIdx.x), - start_idx(blockIdx.x), - k_tile_max(kmax), - atomic_tile_index_(atomic_tile_index), - smem_tile_counter(smem_tile_counter), - atomic_offset(gridDim.x), - divmod_tiles_in_m(uint64_t(tiles_m)) { - update_tile_idx(); - } - CUTLASS_DEVICE void update_tile_idx() { - uint64_t q, r; - divmod_tiles_in_m(q, r, uint64_t(linear_idx)); - tile_m_idx = static_cast(r); - tile_n_idx = static_cast(q) * uint32_t(k_tile_max); - } - CUTLASS_DEVICE uint32_t tile_m() const { return tile_m_idx; } - CUTLASS_DEVICE uint32_t tile_n_base() const { return tile_n_idx; } - CUTLASS_DEVICE uint32_t tiles_m() const { return tiles_in_m; } - - CUTLASS_DEVICE uint32_t tiles_n() const { return tiles_in_n; } + using X = Underscore; + // Accumulator data type for 
main computation + using ElementAccumulator = float; + static int constexpr K_PIPE_MAX = size<3>(ASmemLayout{}); + using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>; + static uint32_t constexpr kTmaTransactionBytes = cutlass::bits_to_bytes( + size(AtomThrShapeMNK{}) * cosize(take<0, 3>(ASmemLayout{})) * cute::sizeof_bits_v); + static constexpr bool kEnableStochasticRounding = kEnableStochasticRounding_; + static constexpr bool kEnableRHTColQuant = kEnableRHTColQuant_; + static constexpr bool kEnableRowQuant = kEnableRowQuant_; + static constexpr bool kEnableSwizzleSFOutput = kEnableSwizzleSFOutput_; + static constexpr bool kUseFastMath = kUseFastMath_; + + // Constant for RHT tensor processing (tile size etc) + static int constexpr RhtTensorSize = 16; + + // Get the total number of tokens to process + // Note that here M is the hidden size, which is the last logical dimension of the input tensor x + // The kernel is designed in column major, so M is the hidden size + size_t sum_token_dims = offsets[num_tensors] / M; + + // Transaction bytes for TMA transfer on RHT tensor blocks + static int constexpr kTmaRhtTensorTransactionBytes = + cutlass::bits_to_bytes(RhtTensorSize * RhtTensorSize * cute::sizeof_bits_v); + static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; + static int constexpr SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + + // Mainloop pipeline stage calculation, vectorization parameters for scaling factors + static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{}); + static int constexpr SFVecSize = 16; + // Swizzle output layout for scaling factor arrays + using SwizzledSFALayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + using SwizzledSFDLayoutAtom = + cutlass::detail::Sm1xxBlockScaledOutputConfig::SfAtom; + + // Mainloop pipeline types for TMA async execution and epilogue cluster scheduling + using MainloopPipeline = + 
cutlass::detail::CustomizedPipelineTmaUmmaAsync; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + using SchedPipeline = cutlass::PipelineCLCFetchAsync; + using SchedPipelineState = typename SchedPipeline::PipelineState; + using SchedThrottlePipeline = cutlass::PipelineAsync; + using SchedThrottlePipelineState = typename SchedThrottlePipeline::PipelineState; + + static_assert(ClusterShape{} == Shape<_1, _1, _1>{}, "ClusterShape must be Shape<_1,_1,_1>"); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + static int constexpr VectorSize = RhtTensorSize; + + // Compile-time safety: static shapes required for shared memory layouts + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + // CUTE_STATIC_ASSERT(is_static::value); + + auto cluster_size = size<0>(cluster_shape); + auto mainloop_tiler = Shape<_128, _16, _128>{}; + auto epilogue_tiler = Shape<_128, _128, _128>{}; + + static int constexpr EpilogueUnrollFactor = size<2>(epilogue_tiler) / size<2>(cluster_tile); + + // Get the appropriate blocks for this Cluster + dim3 cluster_coord_in_grid = cluster_id_in_grid(); + + // Total number of k-tiles + int const K_TILE_MAX = min(packed_N, K) / size<2>(epilogue_tiler); + + struct TileScheduler { + uint32_t tiles_in_m = 0; + uint32_t tiles_in_n = 0; + uint32_t linear_idx = 0; + uint32_t next_linear_idx = 0; + uint32_t start_idx = 0; + uint32_t tile_m_idx = 0; + uint32_t tile_n_idx = 0; + int k_tile_max = 0; + uint32_t *atomic_tile_index_; + uint32_t *smem_tile_counter; + uint32_t atomic_offset; + cutlass::FastDivmodU64 divmod_tiles_in_m; + + CUTLASS_DEVICE TileScheduler(uint32_t tiles_m, uint32_t tiles_n, int kmax, + uint32_t *atomic_tile_index, uint32_t *smem_tile_counter) + : tiles_in_m(tiles_m), + tiles_in_n(tiles_n), + linear_idx(blockIdx.x), + next_linear_idx(blockIdx.x), + start_idx(blockIdx.x), + k_tile_max(kmax), + atomic_tile_index_(atomic_tile_index), + smem_tile_counter(smem_tile_counter), + 
atomic_offset(gridDim.x), + divmod_tiles_in_m(uint64_t(tiles_m)) { + update_tile_idx(); + } + CUTLASS_DEVICE void update_tile_idx() { + uint64_t q, r; + divmod_tiles_in_m(q, r, uint64_t(linear_idx)); + tile_m_idx = static_cast(r); + tile_n_idx = static_cast(q) * uint32_t(k_tile_max); + } + CUTLASS_DEVICE uint32_t tile_m() const { return tile_m_idx; } + CUTLASS_DEVICE uint32_t tile_n_base() const { return tile_n_idx; } + CUTLASS_DEVICE uint32_t tiles_m() const { return tiles_in_m; } - CUTLASS_DEVICE bool is_valid() const { - return cute::elem_less(cute::make_coord(tile_m(), tile_n_base()), - cute::make_coord(tiles_in_m, tiles_in_n)); - } + CUTLASS_DEVICE uint32_t tiles_n() const { return tiles_in_n; } - CUTLASS_DEVICE bool is_first_wave() const { return linear_idx == start_idx; } + CUTLASS_DEVICE bool is_valid() const { + return cute::elem_less(cute::make_coord(tile_m(), tile_n_base()), + cute::make_coord(tiles_in_m, tiles_in_n)); + } - CUTLASS_DEVICE uint32_t get_linear_tile_idx() const { return linear_idx; } + CUTLASS_DEVICE bool is_first_wave() const { return linear_idx == start_idx; } - // Fetch a new tile_id using atomics. - CUTLASS_DEVICE uint32_t fetch_tile_id_counter(int pred) { - uint32_t tile_id_counter = 0; - asm volatile( - "{\n\t" - ".reg .pred p;\n\t" - "setp.eq.u32 p, %2, 1;\n\t" - "@p atom.global.add.u32 %0, [%1], 1; \n\t" - "}" - : "=r"(tile_id_counter) - : "l"(atomic_tile_index_), "r"(pred)); + CUTLASS_DEVICE uint32_t get_linear_tile_idx() const { return linear_idx; } - return tile_id_counter; - } + // Fetch a new tile_id using atomics. 
+ CUTLASS_DEVICE uint32_t fetch_tile_id_counter(int pred) { + uint32_t tile_id_counter = 0; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.eq.u32 p, %2, 1;\n\t" + "@p atom.global.add.u32 %0, [%1], 1; \n\t" + "}" + : "=r"(tile_id_counter) + : "l"(atomic_tile_index_), "r"(pred)); - CUTLASS_DEVICE auto fetch_next_work(SchedPipeline &sched_pipeline, - SchedPipelineState sched_pipeline_consumer_state) { - sched_pipeline.consumer_wait(sched_pipeline_consumer_state); - next_linear_idx = smem_tile_counter[sched_pipeline_consumer_state.index()]; - cutlass::arch::fence_view_async_shared(); - sched_pipeline.consumer_release(sched_pipeline_consumer_state); - return; - } + return tile_id_counter; + } - CUTLASS_DEVICE auto advance_to_next_work(SchedPipeline &sched_pipeline, - SchedPipelineState sched_pipeline_producer_state) { - uint32_t mbarrier_addr = sched_pipeline.producer_get_barrier(sched_pipeline_producer_state); - // Wait for clcID buffer to become empty with a flipped phase - sched_pipeline.producer_acquire(sched_pipeline_producer_state); - auto is_leading_thread = cute::elect_one_sync(); - uint32_t tile_id_counter = fetch_tile_id_counter(is_leading_thread) + atomic_offset; - uint32_t smem_addr = - cute::cast_smem_ptr_to_uint(&smem_tile_counter[sched_pipeline_producer_state.index()]); - if (is_leading_thread) { - cute::store_shared_remote(tile_id_counter, smem_addr, mbarrier_addr, 0); + CUTLASS_DEVICE auto fetch_next_work(SchedPipeline &sched_pipeline, + SchedPipelineState sched_pipeline_consumer_state) { + sched_pipeline.consumer_wait(sched_pipeline_consumer_state); + next_linear_idx = smem_tile_counter[sched_pipeline_consumer_state.index()]; + cutlass::arch::fence_view_async_shared(); + sched_pipeline.consumer_release(sched_pipeline_consumer_state); + return; } - ++sched_pipeline_producer_state; - return sched_pipeline_producer_state; - } + CUTLASS_DEVICE auto advance_to_next_work(SchedPipeline &sched_pipeline, + SchedPipelineState 
sched_pipeline_producer_state) { + uint32_t mbarrier_addr = sched_pipeline.producer_get_barrier(sched_pipeline_producer_state); + // Wait for clcID buffer to become empty with a flipped phase + sched_pipeline.producer_acquire(sched_pipeline_producer_state); + auto is_leading_thread = cute::elect_one_sync(); + uint32_t tile_id_counter = fetch_tile_id_counter(is_leading_thread) + atomic_offset; + uint32_t smem_addr = + cute::cast_smem_ptr_to_uint(&smem_tile_counter[sched_pipeline_producer_state.index()]); + if (is_leading_thread) { + cute::store_shared_remote(tile_id_counter, smem_addr, mbarrier_addr, 0); + } - CUTLASS_DEVICE auto update_work_tile_info() { - linear_idx = next_linear_idx; - update_tile_idx(); - return; - } - }; - - // Allocate and alias shared memory to the kernel's shared storage type - extern __shared__ char shared_memory[]; - using SharedStorage = - SharedStorage; - SharedStorage &shared_storage = *reinterpret_cast(shared_memory); - - // Compute the number of tiles in M and N after tiling and assign scheduler - uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile)))); - uint32_t tiles_in_n = uint32_t(size(ceil_div(sum_token_dims, size<2>(epilogue_tiler)))); - - TileScheduler scheduler(tiles_in_m, tiles_in_n, K_TILE_MAX, tile_scheduler_workspace, - shared_storage.atomic_tile_counter); - - int block_rank_in_cluster = cute::block_rank_in_cluster(); - - // Shapes for accumulated tiles in mainloop and epilogue - auto acc_shape_mma = make_shape(take<0, 2>(mainloop_tiler), _1{}, _1{}); - auto acc_shape_epilogue = make_shape(take<0, 2>(epilogue_tiler), _1{}, _1{}); - - // Shape of the accumulator fragment for the main loop pipeline, with pipeline stages appended - auto acc_mainloop_pipelined_shape = append(acc_shape_mma, Int{}); - auto bulk_tmem_mma = TiledMMA::make_fragment_C(acc_mainloop_pipelined_shape); - - // Number of threads assigned for various epilogue roles depending on quantization settings - static int constexpr 
NumEpilogueColQuantThreadCount = kEnableRHTColQuant ? 128 : 0; - static int constexpr NumEpilogueRowQuantThreadCount = kEnableRowQuant ? 256 : 0; - static int constexpr NumMmaThreadCount = kEnableRHTColQuant ? 32 : 0; - static int constexpr NumMmaIssueThreadCount = kEnableRHTColQuant ? 1 : 0; - static int constexpr NumSchedThreads = 32; - static int constexpr NumMainloopLoadThreads = 32; - static int constexpr NumEpilogueThreads = - NumEpilogueColQuantThreadCount + NumEpilogueRowQuantThreadCount; - - TmemAllocator tmem_allocator{}; - cutlass::arch::NamedBarrier tmem_allocation_result_barrier( - NumMmaThreadCount + NumEpilogueColQuantThreadCount, - cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); - - int warp_idx = cutlass::canonical_warp_idx_sync(); - - // warp assignment - bool is_mma_warp = (warp_idx == 0); - bool is_dma_warp = (warp_idx == 1); - bool is_sched_warp = (warp_idx == 2); - bool is_epilogue_col_quant_warp = (warp_idx >= 4 && warp_idx <= 7); - bool is_epilogue_row_quant_warp = (warp_idx >= 8 && warp_idx <= 15); - - typename MainloopPipeline::Params mainloop_pipeline_params; - if (is_dma_warp) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; - } - if (is_mma_warp) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; - } - mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp; - mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes; - mainloop_pipeline_params.initializing_warp = 0; - mainloop_pipeline_params.num_consumers = NumEpilogueRowQuantThreadCount + NumMmaIssueThreadCount; + ++sched_pipeline_producer_state; + return sched_pipeline_producer_state; + } - MainloopPipeline mainloop_pipeline(shared_storage.mainloop, mainloop_pipeline_params, - cluster_shape, cute::true_type{}, // Perform barrier init - cute::true_type{}); // Delay mask calculation + CUTLASS_DEVICE auto update_work_tile_info() { + linear_idx = next_linear_idx; + update_tile_idx(); + 
return; + } + }; - MainloopPipelineState mainloop_pipe_consumer_state; - MainloopPipelineState mainloop_pipe_producer_state = - cutlass::make_producer_start_state(); + // Allocate and alias shared memory to the kernel's shared storage type + extern __shared__ char shared_memory[]; + using SharedStorage = + SharedStorage; + SharedStorage &shared_storage = *reinterpret_cast(shared_memory); - using AccumulatorPipeline = - cutlass::PipelineUmmaAsync; - using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; - using AccumulatorPipelineInitBarriers = cute::bool_constant; + // Compute the number of tiles in M and N after tiling and assign scheduler + uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile)))); + uint32_t tiles_in_n = uint32_t(size(ceil_div(sum_token_dims, size<2>(epilogue_tiler)))); - AccumulatorPipelineState accumulator_pipe_consumer_state; - AccumulatorPipelineState accumulator_pipe_producer_state = - cutlass::make_producer_start_state(); + TileScheduler scheduler(tiles_in_m, tiles_in_n, K_TILE_MAX, tile_scheduler_workspace, + shared_storage.atomic_tile_counter); - typename AccumulatorPipeline::Params accumulator_pipeline_params; - if (is_mma_warp) { - accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; - } - if (is_epilogue_col_quant_warp) { - accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; - } - // Only one producer thread arrives on this barrier. 
- accumulator_pipeline_params.producer_arv_count = 1; - accumulator_pipeline_params.consumer_arv_count = - size(AtomThrShapeMNK{}) * NumEpilogueColQuantThreadCount; - accumulator_pipeline_params.initializing_warp = 1; - AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, accumulator_pipeline_params, - cluster_shape, AccumulatorPipelineInitBarriers{}, - cute::true_type{}); // Delay mask calculation - typename SchedPipeline::Params sched_pipeline_params; - if (is_sched_warp) { - sched_pipeline_params.role = SchedPipeline::ThreadCategory::ProducerConsumer; - } else { - sched_pipeline_params.role = SchedPipeline::ThreadCategory::Consumer; - } - sched_pipeline_params.producer_blockid = 0; - sched_pipeline_params.producer_arv_count = 1; - sched_pipeline_params.consumer_arv_count = - NumSchedThreads + - cluster_size * (NumMainloopLoadThreads + NumEpilogueThreads + NumMmaThreadCount); - sched_pipeline_params.transaction_bytes = sizeof(uint32_t); - sched_pipeline_params.initializing_warp = 3; - SchedPipeline sched_pipeline(shared_storage.sched, sched_pipeline_params, cluster_shape); - SchedPipelineState sched_pipeline_consumer_state; - SchedPipelineState sched_pipeline_producer_state = - cutlass::make_producer_start_state(); - - typename SchedThrottlePipeline::Params sched_throttle_pipeline_params; - if (is_dma_warp) { - sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Producer; - } - if (is_sched_warp) { - sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Consumer; - } - sched_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; - sched_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; - sched_throttle_pipeline_params.dst_blockid = 0; - sched_throttle_pipeline_params.initializing_warp = 4; - - SchedThrottlePipeline sched_throttle_pipeline(shared_storage.sched_throttle, - sched_throttle_pipeline_params); - SchedThrottlePipelineState sched_pipeline_throttle_consumer_state; - 
SchedThrottlePipelineState sched_pipeline_throttle_producer_state = - cutlass::make_producer_start_state(); - - if (warp_idx == 2 && elect_one_sync()) { - cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); - } - __syncthreads(); - - // Warp group roles: DMA (global->shared copy), MMA (tensor core gemm), scheduler, column quantizer, row quantizer - if (is_dma_warp) { - // Warp responsible for loading input from global to shared memory using TMA (Tensor Memory Access). - cutlass::arch::warpgroup_reg_dealloc<32>(); - // Get TMA tensors for input matrix A and B (Hadamard/transform matrix) from global memory. - Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, packed_N)); - Tensor mB = tma_load_b.get_tma_tensor(make_shape(RhtTensorSize, RhtTensorSize)); - - // Partition tensors for tiling according to the mainloop and cluster tilers. - Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); - Tensor gB_nk = - local_tile(mB, cluster_tile, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,k) - - // Shared memory tensors for pipeline - Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), - sAlayout); // (MMA,MMA_M,MMA_N,PIPE) - Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), - sBlayout); // (MMA,MMA_N,MMA_K,PIPE) - - // Determine warp/tile positioning int block_rank_in_cluster = cute::block_rank_in_cluster(); - ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx - // Partition global to local fragments for A and B - Tensor tCgA = thr_mma.partition_A(gA_mk); // (MMA,MMA_M,MMA_K,k) - Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) - - Layout cta_layout_mnk = make_layout(cluster_shape); - Layout cta_layout_vmnk = - tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); - auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); - - auto [tAgA, tAsA] = - tma_partition(tma_load_a, 
get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), - group_modes<0, 3>(tCsA), group_modes<0, 3>(tCgA)); - - auto [tBgB, tBsB] = - tma_partition(tma_load_b, get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), - group_modes<0, 3>(tCsB), group_modes<0, 3>(tCgB)); - - uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); - uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); - if constexpr (kEnableRHTColQuant) { - if (elect_one_sync()) { - cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], - kTmaRhtTensorTransactionBytes); - copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_, 0, 0), - tBsB(_, 0)); - } - } - do { - // is_first_wave indicates whether this scheduler wave is the first among a group. - bool is_first_wave = scheduler.is_first_wave(); - uint32_t skip_wait = is_first_wave; - auto tAgA_mk = tAgA(_, scheduler.tile_m(), _); - int k_tile = 0; - - sched_throttle_pipeline.producer_acquire(sched_pipeline_throttle_producer_state); - sched_throttle_pipeline.producer_commit(sched_pipeline_throttle_producer_state); - ++sched_pipeline_throttle_producer_state; - CUTLASS_PRAGMA_NO_UNROLL - while (k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n()) { - int k_tile_idx_n = scheduler.tile_n_base() + k_tile; - ++k_tile; - skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount); - mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); - using BarrierType = typename MainloopPipeline::ProducerBarrierType; - BarrierType *tma_barrier = - mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); - int write_stage = mainloop_pipe_producer_state.index(); - ++mainloop_pipe_producer_state; - if (cute::elect_one_sync()) { - copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_, k_tile_idx_n), - tAsA(_, write_stage)); - } - } - scheduler.fetch_next_work(sched_pipeline, 
sched_pipeline_consumer_state); - ++sched_pipeline_consumer_state; - scheduler.update_work_tile_info(); - // scheduler.advance(); - } while (scheduler.is_valid()); - mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); - } else if (is_mma_warp) { - // This warp executes the main tensor core matrix-multiply-accumulate for the Hadamard transform. - cutlass::arch::warpgroup_reg_dealloc<32>(); - if constexpr (kEnableRHTColQuant) { - // Setup shared memory fragments for A and B tiles. + // Shapes for accumulated tiles in mainloop and epilogue + auto acc_shape_mma = make_shape(take<0, 2>(mainloop_tiler), _1{}, _1{}); + auto acc_shape_epilogue = make_shape(take<0, 2>(epilogue_tiler), _1{}, _1{}); + + // Shape of the accumulator fragment for the main loop pipeline, with pipeline stages appended + auto acc_mainloop_pipelined_shape = append(acc_shape_mma, Int{}); + auto bulk_tmem_mma = TiledMMA::make_fragment_C(acc_mainloop_pipelined_shape); + + // Number of threads assigned for various epilogue roles depending on quantization settings + static int constexpr NumEpilogueColQuantThreadCount = kEnableRHTColQuant ? 128 : 0; + static int constexpr NumEpilogueRowQuantThreadCount = kEnableRowQuant ? 256 : 0; + static int constexpr NumMmaThreadCount = kEnableRHTColQuant ? 32 : 0; + static int constexpr NumMmaIssueThreadCount = kEnableRHTColQuant ? 
1 : 0; + static int constexpr NumSchedThreads = 32; + static int constexpr NumMainloopLoadThreads = 32; + static int constexpr NumEpilogueThreads = + NumEpilogueColQuantThreadCount + NumEpilogueRowQuantThreadCount; + + TmemAllocator tmem_allocator{}; + cutlass::arch::NamedBarrier tmem_allocation_result_barrier( + NumMmaThreadCount + NumEpilogueColQuantThreadCount, + cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + + // warp assignment + bool is_mma_warp = (warp_idx == 0); + bool is_dma_warp = (warp_idx == 1); + bool is_sched_warp = (warp_idx == 2); + bool is_epilogue_col_quant_warp = (warp_idx >= 4 && warp_idx <= 7); + bool is_epilogue_row_quant_warp = (warp_idx >= 8 && warp_idx <= 15); + + typename MainloopPipeline::Params mainloop_pipeline_params; + if (is_dma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (is_mma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp; + mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + mainloop_pipeline_params.num_consumers = + NumEpilogueRowQuantThreadCount + NumMmaIssueThreadCount; + + MainloopPipeline mainloop_pipeline(shared_storage.mainloop, mainloop_pipeline_params, + cluster_shape, cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = + cutlass::make_producer_start_state(); + + using AccumulatorPipeline = + cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + using AccumulatorPipelineInitBarriers = cute::bool_constant; + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState 
accumulator_pipe_producer_state = + cutlass::make_producer_start_state(); + + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (is_mma_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (is_epilogue_col_quant_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = + size(AtomThrShapeMNK{}) * NumEpilogueColQuantThreadCount; + accumulator_pipeline_params.initializing_warp = 1; + AccumulatorPipeline accumulator_pipeline( + shared_storage.accumulator, accumulator_pipeline_params, cluster_shape, + AccumulatorPipelineInitBarriers{}, cute::true_type{}); // Delay mask calculation + typename SchedPipeline::Params sched_pipeline_params; + if (is_sched_warp) { + sched_pipeline_params.role = SchedPipeline::ThreadCategory::ProducerConsumer; + } else { + sched_pipeline_params.role = SchedPipeline::ThreadCategory::Consumer; + } + sched_pipeline_params.producer_blockid = 0; + sched_pipeline_params.producer_arv_count = 1; + sched_pipeline_params.consumer_arv_count = + NumSchedThreads + + cluster_size * (NumMainloopLoadThreads + NumEpilogueThreads + NumMmaThreadCount); + sched_pipeline_params.transaction_bytes = sizeof(uint32_t); + sched_pipeline_params.initializing_warp = 3; + SchedPipeline sched_pipeline(shared_storage.sched, sched_pipeline_params, cluster_shape); + SchedPipelineState sched_pipeline_consumer_state; + SchedPipelineState sched_pipeline_producer_state = + cutlass::make_producer_start_state(); + + typename SchedThrottlePipeline::Params sched_throttle_pipeline_params; + if (is_dma_warp) { + sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Producer; + } + if (is_sched_warp) { + sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Consumer; + } + 
sched_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + sched_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + sched_throttle_pipeline_params.dst_blockid = 0; + sched_throttle_pipeline_params.initializing_warp = 4; + + SchedThrottlePipeline sched_throttle_pipeline(shared_storage.sched_throttle, + sched_throttle_pipeline_params); + SchedThrottlePipelineState sched_pipeline_throttle_consumer_state; + SchedThrottlePipelineState sched_pipeline_throttle_producer_state = + cutlass::make_producer_start_state(); + + if (warp_idx == 2 && elect_one_sync()) { + cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); + } + __syncthreads(); + + // Warp group roles: DMA (global->shared copy), MMA (tensor core gemm), scheduler, column quantizer, row quantizer + if (is_dma_warp) { + // Warp responsible for loading input from global to shared memory using TMA (Tensor Memory Accelerator). + cutlass::arch::warpgroup_reg_dealloc<32>(); + // Get TMA tensors for input matrix A and B (Hadamard/transform matrix) from global memory. + Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, packed_N)); + Tensor mB = tma_load_b.get_tma_tensor(make_shape(RhtTensorSize, RhtTensorSize)); + + // Partition tensors for tiling according to the mainloop and cluster tilers. 
+ Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor gB_nk = + local_tile(mB, cluster_tile, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,k) + + // Shared memory tensors for pipeline Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + // Determine warp/tile positioning int block_rank_in_cluster = cute::block_rank_in_cluster(); ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx - // Allocate "fragments" -- these are actually umma smem descriptors - Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) - - mma.accumulate_ = UMMA::ScaleOut::Zero; - - tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, - &shared_storage.tmem_base_ptr); - __syncwarp(); - tmem_allocation_result_barrier.arrive(); - uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; - bulk_tmem_mma.data() = tmem_base_ptr; - // Wait until the B (Hadamard) tensor copy is complete - cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); - do { - uint32_t skip_wait = K_TILE_MAX <= 0; + // Partition global to local fragments for A and B + Tensor tCgA = thr_mma.partition_A(gA_mk); // (MMA,MMA_M,MMA_K,k) + Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) + + Layout cta_layout_mnk = make_layout(cluster_shape); + Layout cta_layout_vmnk = + tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + auto [tAgA, tAsA] = + tma_partition(tma_load_a, get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0, 3>(tCsA), group_modes<0, 3>(tCgA)); + + auto [tBgB, tBsB] = + tma_partition(tma_load_b, get<1>(cta_coord_vmnk), 
make_layout(size<1>(cta_layout_vmnk)), + group_modes<0, 3>(tCsB), group_modes<0, 3>(tCgB)); + + uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + if constexpr (kEnableRHTColQuant) { + if (elect_one_sync()) { + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], + kTmaRhtTensorTransactionBytes); + copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_, 0, 0), + tBsB(_, 0)); + } + } - auto barrier_token = - mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); - scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); - ++sched_pipeline_consumer_state; + do { + // is_first_wave indicates whether this scheduler wave is the first among a group. + bool is_first_wave = scheduler.is_first_wave(); + uint32_t skip_wait = is_first_wave; + auto tAgA_mk = tAgA(_, scheduler.tile_m(), _); + int k_tile = 0; + + sched_throttle_pipeline.producer_acquire(sched_pipeline_throttle_producer_state); + sched_throttle_pipeline.producer_commit(sched_pipeline_throttle_producer_state); + ++sched_pipeline_throttle_producer_state; CUTLASS_PRAGMA_NO_UNROLL - for (int k_tile = 0; - k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) { - mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); - int read_stage = mainloop_pipe_consumer_state.index(); - auto tCrA_mk = tCrA(_, _, _, read_stage); - auto tCrB_nk = tCrB(_, _, 0, 0); - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA) / EpilogueUnrollFactor; ++k_block) { - int accumulator_k_block = - accumulator_pipe_producer_state.index() * EpilogueUnrollFactor; - int tCrA_k_block = k_block * EpilogueUnrollFactor; - accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < EpilogueUnrollFactor; i++) { - auto 
accumulators = bulk_tmem_mma(_, _, _, accumulator_k_block + i); - gemm(mma, tCrA_mk(_, _, tCrA_k_block + i), tCrB_nk, accumulators); - } - - accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); - ++accumulator_pipe_producer_state; - } - auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; - ++mainloop_pipe_consumer_state; + while (k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n()) { + int k_tile_idx_n = scheduler.tile_n_base() + k_tile; ++k_tile; - skip_wait = k_tile >= K_TILE_MAX; - mainloop_pipeline.umma_consumer_release(curr_mainloop_pipe_consumer_state); - barrier_token = - mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount); + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType *tma_barrier = + mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + if (cute::elect_one_sync()) { + copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_, k_tile_idx_n), + tAsA(_, write_stage)); + } } + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; scheduler.update_work_tile_info(); + // scheduler.advance(); } while (scheduler.is_valid()); - tmem_allocator.release_allocation_lock(); - accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); - tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); - } - } else if (is_sched_warp) { - // Scheduler warp manages tile assignment and pipeline progress for warps - cutlass::arch::warpgroup_reg_dealloc<32>(); - do { - sched_throttle_pipeline.consumer_wait(sched_pipeline_throttle_consumer_state); - sched_throttle_pipeline.consumer_release(sched_pipeline_throttle_consumer_state); 
- ++sched_pipeline_throttle_consumer_state; - sched_pipeline_producer_state = - scheduler.advance_to_next_work(sched_pipeline, sched_pipeline_producer_state); - scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); - ++sched_pipeline_consumer_state; - scheduler.update_work_tile_info(); - } while (scheduler.is_valid()); - } else if (is_epilogue_col_quant_warp) { - // Warp responsible for quantizing output of Hadamard transform to FP4 for columnwise usage, - // and writing result tensors/scales to global memory. - cutlass::arch::warpgroup_reg_alloc<192>(); - if constexpr (kEnableRHTColQuant) { - using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; - - auto acc_epilogue_pipelined_shape = - append(acc_shape_epilogue, Int{}); - auto bulk_tmem_epilogue_layout = make_layout( - acc_epilogue_pipelined_shape, - make_stride(stride<0>(bulk_tmem_mma), Int<0>{}, Int<0>{}, size<1>(epilogue_tiler))); - auto bulk_tmem_epilogue = make_tensor(make_tmem_ptr(), bulk_tmem_epilogue_layout); - - // Use 256-bit fragments for aligned bulk stores - static int constexpr FragmentSize = 256 / sizeof_bits_v; - - // Wait for TMEM allocation for this pipeline to finish - tmem_allocation_result_barrier.arrive_and_wait(); - uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; - bulk_tmem_epilogue.data() = tmem_base_ptr; - int global_thread_idx = threadIdx.x; - int local_thread_idx = global_thread_idx % cutlass::NumThreadsPerWarpGroup; - // g2s load all global_d_amax - CUTLASS_PRAGMA_NO_UNROLL - for (int g = local_thread_idx; g < num_tensors; g += NumEpilogueColQuantThreadCount) { - shared_storage.global_d_amax[g] = __ldg(reinterpret_cast(amax_colwise + g)); - } - - size_t rng_seed = 0; - size_t rng_offset = 0; - // Setup RNG for stochastic rounding - if constexpr (kEnableStochasticRounding) { - rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; - rng_offset = rng_state != nullptr ? 
__ldg(rng_state + 1) : 0; - } - // TODO(zhongbo): double check the logic here - int group_idx = get_current_tensor_id(shape_rep, num_tensors, - (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M, - packed_N, M, offsets); - - // Determine quantization scale factor layouts/output splits for this group - TSFDLayout sfd_layout; - int cur_N = static_cast(first_dims[group_idx]); - if constexpr (kEnableSwizzleSFOutput) { - sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{}); - } else { - sfd_layout = make_layout(make_shape(M, make_shape(Int{}, cur_N / SFVecSize)), - make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{}))); - } - // Build output tensors for columns and their quant scales - // TODO(zhongbo): double check the logic here - Tensor mD = make_tensor(cute::subbyte_iterator(reinterpret_cast( - reinterpret_cast(QA_COLWISE) + offsets[group_idx] / 2)), - make_shape(M, cur_N), DStride{}); // (M,packed_N) - Tensor gD_mn = - local_tile(mD, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N) - - // for every tensor [x, y] row major, x y both a multiple of 128 - // both of its rowwise and colwise scaling factors will have exactly x * y / 16 elements in FP8 E4M3 - Tensor mSFD = make_tensor( - make_gmem_ptr(reinterpret_cast(reinterpret_cast(SFA_COLWISE) + - offsets[group_idx] / kNVFP4BlockSize)), - sfd_layout); - Tensor gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _), - Step<_1, _1, X>{}); // (BLK_M,BLK_N) - - Tensor gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler)); - - // Setup tile-level TMEM (t2r) and global memory (r2g) copy descriptors - auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_, _, _, _0{})); - auto tiled_r2g = - make_tiled_copy_D(Copy_Atom{}, tiled_t2r); - auto thr_t2r = tiled_t2r.get_slice(local_thread_idx); - auto thr_r2g = tiled_r2g.get_slice(local_thread_idx); - - cutlass::arch::NamedBarrier::sync(NumEpilogueColQuantThreadCount, - 
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} - static constexpr float fp4_max = 6.0f; - static constexpr float fp8_max = 448.0f; - static constexpr float fp4_max_inv = 1.0f / fp4_max; - float c_global_amax_val = shared_storage.global_d_amax[group_idx]; - float global_encode_scale = c_global_amax_val > 0.0f - ? cutlass::minimum_with_nan_propagation{}( - (fp8_max * fp4_max) / c_global_amax_val, - cutlass::platform::numeric_limits::max()) - : 1.0f; - float global_decode_scale = 1.0f / global_encode_scale; - - // Scaling factor for fast math path - float global_encode_scale_multiplier = 1.0f; - if constexpr (kUseFastMath) { - global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } else if (is_mma_warp) { + // This warp executes the main tensor core matrix-multiply-accumulate for the Hadamard transform. + cutlass::arch::warpgroup_reg_dealloc<32>(); + if constexpr (kEnableRHTColQuant) { + // Setup shared memory fragments for A and B tiles. 
+ Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), + sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), + sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + // Allocate "fragments" -- these are actually umma smem descriptors + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) + + mma.accumulate_ = UMMA::ScaleOut::Zero; + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, + &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_mma.data() = tmem_base_ptr; + // Wait until the B (Hadamard) tensor copy is complete + cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); + do { + uint32_t skip_wait = K_TILE_MAX <= 0; + + auto barrier_token = + mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) { + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + int read_stage = mainloop_pipe_consumer_state.index(); + auto tCrA_mk = tCrA(_, _, _, read_stage); + auto tCrB_nk = tCrB(_, _, 0, 0); + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) / EpilogueUnrollFactor; ++k_block) { + int accumulator_k_block = + accumulator_pipe_producer_state.index() * EpilogueUnrollFactor; + int tCrA_k_block = k_block * EpilogueUnrollFactor; + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + CUTLASS_PRAGMA_UNROLL + for (int 
i = 0; i < EpilogueUnrollFactor; i++) { + auto accumulators = bulk_tmem_mma(_, _, _, accumulator_k_block + i); + gemm(mma, tCrA_mk(_, _, tCrA_k_block + i), tCrB_nk, accumulators); + } + + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + ++accumulator_pipe_producer_state; + } + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + ++mainloop_pipe_consumer_state; + ++k_tile; + skip_wait = k_tile >= K_TILE_MAX; + mainloop_pipeline.umma_consumer_release(curr_mainloop_pipe_consumer_state); + barrier_token = + mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + } + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + tmem_allocator.release_allocation_lock(); + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); } - + } else if (is_sched_warp) { + // Scheduler warp manages tile assignment and pipeline progress for warps + cutlass::arch::warpgroup_reg_dealloc<32>(); do { + sched_throttle_pipeline.consumer_wait(sched_pipeline_throttle_consumer_state); + sched_throttle_pipeline.consumer_release(sched_pipeline_throttle_consumer_state); + ++sched_pipeline_throttle_consumer_state; + sched_pipeline_producer_state = + scheduler.advance_to_next_work(sched_pipeline, sched_pipeline_producer_state); scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); ++sched_pipeline_consumer_state; + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } else if (is_epilogue_col_quant_warp) { + // Warp responsible for quantizing output of Hadamard transform to FP4 for columnwise usage, + // and writing result tensors/scales to global memory. 
+ cutlass::arch::warpgroup_reg_alloc<192>(); + if constexpr (kEnableRHTColQuant) { + using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; + + auto acc_epilogue_pipelined_shape = + append(acc_shape_epilogue, Int{}); + auto bulk_tmem_epilogue_layout = make_layout( + acc_epilogue_pipelined_shape, + make_stride(stride<0>(bulk_tmem_mma), Int<0>{}, Int<0>{}, size<1>(epilogue_tiler))); + auto bulk_tmem_epilogue = make_tensor(make_tmem_ptr(), bulk_tmem_epilogue_layout); + + // Use 256-bit fragments for aligned bulk stores + static int constexpr FragmentSize = 256 / sizeof_bits_v; + + // Wait for TMEM allocation for this pipeline to finish + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_epilogue.data() = tmem_base_ptr; + int global_thread_idx = threadIdx.x; + int local_thread_idx = global_thread_idx % cutlass::NumThreadsPerWarpGroup; + // g2s load all global_d_amax CUTLASS_PRAGMA_NO_UNROLL - for (int k_tile = 0; - k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); - ++k_tile) { - int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler); - - // TODO(zhongbo): double check the logic here - int cur_group_idx = get_current_tensor_id(shape_rep, num_tensors, - global_tile_n_offset * M, packed_N, M, offsets); - - if (cur_group_idx != group_idx) { - group_idx = cur_group_idx; - c_global_amax_val = shared_storage.global_d_amax[group_idx]; - // update amax - global_encode_scale = c_global_amax_val > 0.0f - ? 
cutlass::minimum_with_nan_propagation{}( - (fp8_max * fp4_max) / c_global_amax_val, - cutlass::platform::numeric_limits::max()) - : 1.0f; - global_decode_scale = 1.0f / global_encode_scale; - if constexpr (kUseFastMath) { - global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; - } - // TODO(zhongbo): double check the logic here - cur_N = first_dims[group_idx]; - if constexpr (kEnableSwizzleSFOutput) { - sfd_layout = - tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{}); - } else { - sfd_layout = - make_layout(make_shape(M, make_shape(Int{}, cur_N / SFVecSize)), - make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{}))); - } - // update tensor - mD = make_tensor(cute::subbyte_iterator(reinterpret_cast( - reinterpret_cast(QA_COLWISE) + offsets[group_idx] / 2)), - make_shape(M, cur_N), DStride{}); - gD_mn = local_tile(mD, epilogue_tiler, make_coord(_, _, _), - Step<_1, _1, X>{}); // (BLK_M,BLK_N) - mSFD = make_tensor( - make_gmem_ptr(reinterpret_cast(reinterpret_cast(SFA_COLWISE) + - offsets[group_idx] / kNVFP4BlockSize)), - sfd_layout); - gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _), - Step<_1, _1, X>{}); // (BLK_M,BLK_N) + for (int g = local_thread_idx; g < num_tensors; g += NumEpilogueColQuantThreadCount) { + shared_storage.global_d_amax[g] = __ldg(reinterpret_cast(amax_colwise + g)); + } - gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler)); - } - int group_start_offset = offsets[group_idx] / M; - int local_tile_n_idx = - (global_tile_n_offset - group_start_offset) / size<1>(epilogue_tiler); - Tensor tDgD_mn = gD_mn_view(_, _, _, scheduler.tile_m(), local_tile_n_idx); - - Tensor tDgSFD_mn = gSFD_mn(_, _, scheduler.tile_m(), local_tile_n_idx); - accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); - - auto Acc = bulk_tmem_epilogue(_, _, _, accumulator_pipe_consumer_state.index()); - Tensor tDtAcc = thr_t2r.partition_S(Acc); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) - Tensor tDgD = 
thr_t2r.partition_D(tDgD_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) - - Tensor tTR_rAcc = - make_tensor(shape(tDgD)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) - Tensor tDrD = make_tensor(shape(tDgD)); - Tensor tTR_rAcc_frag = - recast>(coalesce(tTR_rAcc)); - Tensor tDrD_frag = recast>(coalesce(tDrD)); - - Tensor src = thr_r2g.retile_S(tDrD); - Tensor dst = thr_r2g.retile_D(tDgD); - - Tensor tDgSFD_view = make_tensor( - tDgSFD_mn.data(), make_layout(make_shape(shape(tDgSFD_mn), Int<1>{}, Int<1>{}), - make_stride(stride(tDgSFD_mn), Int<0>{}, Int<0>{}))); - Tensor tDgSFD = filter(thr_t2r.partition_D(tDgSFD_view)); - Tensor tDrSFD = make_tensor(shape(tDgSFD)); - - static int constexpr NumVecs = size(tDgD) / VectorSize; - Tensor tD_rRowSFD_frg = recast>(tDrSFD); - - // Compute amax and quantization scales for this tile - cutlass::maximum_absolute_value_reduction, - true> - amax_reduction; - cutlass::Array vec_maxs; - cutlass::Array pvscales; - // Copy from TMEM to registers - copy(tiled_t2r, tDtAcc, tTR_rAcc); - cutlass::arch::fence_view_async_tmem_load(); - accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); - ++accumulator_pipe_consumer_state; - - if constexpr (!kUseFastMath) { - // Downcast to BF16 for bit-wise compatibility with - // unfused kernels - auto convert_accum_to_bf16 = - cutlass::NumericArrayConverter{}; - auto convert_bf16_to_accum = - cutlass::NumericArrayConverter{}; - tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); - tTR_rAcc_frag(_1{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_1{}))); - } + size_t rng_seed = 0; + size_t rng_offset = 0; + // Setup RNG for stochastic rounding + if constexpr (kEnableStochasticRounding) { + rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; + rng_offset = rng_state != nullptr ? 
__ldg(rng_state + 1) : 0; + } + // TODO(zhongbo): double check the logic here + int group_idx = get_current_tensor_id( + shape_rep, num_tensors, (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M, + packed_N, M, offsets); + + // Determine quantization scale factor layouts/output splits for this group + TSFDLayout sfd_layout; + int cur_N = static_cast(first_dims[group_idx]); + if constexpr (kEnableSwizzleSFOutput) { + sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{}); + } else { + sfd_layout = make_layout(make_shape(M, make_shape(Int{}, cur_N / SFVecSize)), + make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{}))); + } + // Build output tensors for columns and their quant scales + // TODO(zhongbo): double check the logic here + Tensor mD = make_tensor(cute::subbyte_iterator(reinterpret_cast( + reinterpret_cast(QA_COLWISE) + offsets[group_idx] / 2)), + make_shape(M, cur_N), DStride{}); // (M,packed_N) + Tensor gD_mn = local_tile(mD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) - auto compute_frgs = reinterpret_cast *>( - tTR_rAcc_frag.data()); - auto output_frgs = reinterpret_cast *>(tDrD_frag.data()); - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < NumVecs; v++) { - vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); - } + // for every tensor [x, y] row major, x y both a multiple of 128 + // both of its rowwise and colwise scaling factors will have exactly x * y / 16 elements in FP8 E4M3 + Tensor mSFD = make_tensor( + make_gmem_ptr(reinterpret_cast(reinterpret_cast(SFA_COLWISE) + + offsets[group_idx] / kNVFP4BlockSize)), + sfd_layout); + Tensor gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) + + Tensor gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler)); + + // Setup tile-level TMEM (t2r) and global memory (r2g) copy descriptors + auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_, _, _, 
_0{})); + auto tiled_r2g = + make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_t2r = tiled_t2r.get_slice(local_thread_idx); + auto thr_r2g = tiled_r2g.get_slice(local_thread_idx); + + cutlass::arch::NamedBarrier::sync(NumEpilogueColQuantThreadCount, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} + static constexpr float fp4_max = 6.0f; + static constexpr float fp8_max = 448.0f; + static constexpr float fp4_max_inv = 1.0f / fp4_max; + float c_global_amax_val = shared_storage.global_d_amax[group_idx]; + float global_encode_scale = c_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / c_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + float global_decode_scale = 1.0f / global_encode_scale; + + // Scaling factor for fast math path + float global_encode_scale_multiplier = 1.0f; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } - if constexpr (kUseFastMath) { - // Fast math: multiply with precomputed reciprocal - pvscales = cutlass::multiplies>{}( - vec_maxs, global_encode_scale_multiplier); - } else { - // Accurate math: perform division - pvscales = - cutlass::divides>{}(vec_maxs, fp4_max); - pvscales = cutlass::multiplies>{}( - pvscales, global_encode_scale); - } - auto pvscales_cvted = - cutlass::NumericArrayConverter{}(pvscales); - - tD_rRowSFD_frg(_0{}) = pvscales_cvted; - auto qpvscale_ups = cutlass::NumericArrayConverter{}( - tD_rRowSFD_frg(_0{})); - auto qpvscale_scaled = cutlass::multiplies>{}( - qpvscale_ups, global_decode_scale); - cutlass::Array acc_scales; - if constexpr (kUseFastMath) { - // Fast math: compute approximate reciprocal - acc_scales = - cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); - } else { - // Accurate math: compute reciprocal with division - acc_scales = cutlass::divides>{}( - 1.0, qpvscale_scaled); - } + 
do { + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); + ++k_tile) { + int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler); - // Prepare stochastic rounding random state if enabled - uint4 random_uint4 = uint4{0, 0, 0, 0}; - transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; - // "Prefetch" a stochastic rounding state for the first tile - if constexpr (kEnableStochasticRounding) { - const size_t rng_sequence = global_thread_idx + k_tile * 512 + - scheduler.get_linear_tile_idx() * K_TILE_MAX * 512; - rng.init(rng_seed, rng_sequence, rng_offset); - } - CUTLASS_PRAGMA_UNROLL - // Apply round/quantize to each fragment, with or without stochastic rounding - for (int v = 0; v < NumVecs; v++) { - auto acc_scale = cutlass::minimum_with_nan_propagation{}( - acc_scales[v], cutlass::platform::numeric_limits::max()); - if constexpr (kEnableStochasticRounding) { - random_uint4 = rng.generate4(); - output_frgs[v] = StochasticNumericConverter( - cutlass::multiplies>{}( - compute_frgs[v], acc_scale), - *reinterpret_cast *>(&random_uint4)); - } else { - output_frgs[v] = cutlass::NumericArrayConverter{}( - cutlass::multiplies>{}( - compute_frgs[v], acc_scale)); + // TODO(zhongbo): double check the logic here + int cur_group_idx = get_current_tensor_id( + shape_rep, num_tensors, global_tile_n_offset * M, packed_N, M, offsets); + + if (cur_group_idx != group_idx) { + group_idx = cur_group_idx; + c_global_amax_val = shared_storage.global_d_amax[group_idx]; + // update amax + global_encode_scale = c_global_amax_val > 0.0f + ? 
cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / c_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + global_decode_scale = 1.0f / global_encode_scale; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + // TODO(zhongbo): double check the logic here + cur_N = first_dims[group_idx]; + if constexpr (kEnableSwizzleSFOutput) { + sfd_layout = + tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{}); + } else { + sfd_layout = + make_layout(make_shape(M, make_shape(Int{}, cur_N / SFVecSize)), + make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{}))); + } + // update tensor + mD = make_tensor(cute::subbyte_iterator(reinterpret_cast( + reinterpret_cast(QA_COLWISE) + offsets[group_idx] / 2)), + make_shape(M, cur_N), DStride{}); + gD_mn = local_tile(mD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) + mSFD = make_tensor(make_gmem_ptr(reinterpret_cast( + reinterpret_cast(SFA_COLWISE) + + offsets[group_idx] / kNVFP4BlockSize)), + sfd_layout); + gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N) + + gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler)); } - } - - // Write quantized FP4 tile and dequant scale to gmem - copy(tiled_r2g, src, dst); - copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFD, tDgSFD); - } - scheduler.update_work_tile_info(); - } while (scheduler.is_valid()); - } - } else if (is_epilogue_row_quant_warp) { - // Warp responsible for quantizing the input (before Hadamard transform) to FP4 for row-wise usage. 
- cutlass::arch::warpgroup_reg_alloc<136>(); - if constexpr (kEnableRowQuant) { - using S2RVectorType = uint128_t; - - int global_thread_idx = threadIdx.x; - int local_thread_idx = global_thread_idx % 256; - size_t rng_seed = 0; - size_t rng_offset = 0; - // g2s load all global_a_amax for all groups/tensors - CUTLASS_PRAGMA_NO_UNROLL - for (int g = local_thread_idx; g < num_tensors; g += NumEpilogueRowQuantThreadCount) { - shared_storage.global_a_amax[g] = __ldg(reinterpret_cast(amax_rowwise + g)); - } - // RNG for stochastic rounding - if constexpr (kEnableStochasticRounding) { - rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; - rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0; - } - // Input/output tensors/partitions for row quant warp - Tensor mQA = - make_tensor(cute::subbyte_iterator(QA), make_layout(make_shape(M, packed_N), dQA)); - Tensor gQA_mn = local_tile(mQA, epilogue_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); - Tensor mSFA = make_tensor(make_gmem_ptr(SFA), sfa_layout); - - Tensor gSFA_mn = local_tile(mSFA, epilogue_tiler, make_coord(_, _, _), - Step<_1, X, _1>{}); // (BLK_M,BLK_N) - // Swizzled shared memory A tile, with layout - Tensor sA = as_position_independent_swizzle_tensor(group_modes<0, 2>( - coalesce(make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), - sAlayout)))); // (BLOCK_M, BLOCK_M,PIPE) - - // Set up layouts for partitioning – tile-by-warp, with vector granularity - using S2RWarpLayout = Layout>; - using WarpGroupLayout = Layout>; - using S2RThreadLayout = decltype(blocked_product(S2RWarpLayout{}, WarpGroupLayout{})); - using S2RValLayout = Layout, _1>>; - using S2RAtomA = Copy_Atom; - using R2GAtomQA = Copy_Atom; - using R2GAtomSFA = Copy_Atom; - auto tiled_s2r = make_tiled_copy(S2RAtomA{}, S2RThreadLayout{}, S2RValLayout{}); - auto tiled_r2g_QA = make_tiled_copy(R2GAtomQA{}, S2RThreadLayout{}, S2RValLayout{}); - auto tiled_r2g_SFA = make_tiled_copy(R2GAtomSFA{}, S2RThreadLayout{}, S2RValLayout{}); 
- - auto thr_s2r = tiled_s2r.get_slice(local_thread_idx); - auto thr_r2g_QA = tiled_r2g_QA.get_slice(local_thread_idx); - auto thr_r2g_SFA = tiled_r2g_SFA.get_slice(local_thread_idx); - Tensor tQAsA = thr_s2r.partition_S(sA); // (Copy, Copy_M, Copy_N, PIPE) - - // Allocate temporary register tensors for copying quantization => output - Tensor tQArA = make_tensor_like( - make_layout(tQAsA(_, _, _, _0{}).shape())); // (Copy, Copy_M, Copy_N) - Tensor tQAgQA = thr_r2g_QA.partition_S(gQA_mn); - Tensor tQArQA = make_tensor_like(tQAgQA(_, _, _, _0{}, _0{})); - - Tensor tQAgSFA = thr_r2g_SFA.partition_S(gSFA_mn); - Tensor tQArSFA = make_tensor_like(tQAgSFA(_, _, _, _0{}, _0{})); - - // Will result in barrier_id=10 passed to bar.sync instr as cutlass adds 8 - // in order to go over the reserved named barrier count. - constexpr int row_quant_barrier_id = 2; - cutlass::arch::NamedBarrier::sync(NumEpilogueRowQuantThreadCount, row_quant_barrier_id); - - int group_idx = get_current_tensor_id(shape_rep, num_tensors, - (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M, - packed_N, M, offsets); - float a_global_amax_val = shared_storage.global_a_amax[group_idx]; - // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} - static constexpr float fp4_max = 6.0f; - static constexpr float fp8_max = 448.0f; - static constexpr float fp4_max_inv = 1.0f / fp4_max; - float global_encode_scale = a_global_amax_val > 0.0f - ? 
cutlass::minimum_with_nan_propagation{}( - (fp8_max * fp4_max) / a_global_amax_val, - cutlass::platform::numeric_limits::max()) - : 1.0f; - - float global_decode_scale = 1.0f / global_encode_scale; - float global_encode_scale_multiplier = 1.0f; - if constexpr (kUseFastMath) { - global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; - } - auto sfa_converter = cutlass::NumericConverter{}; - do { - CUTLASS_PRAGMA_NO_UNROLL - for (int k_tile = 0; - k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) { - int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler); - - int cur_group_idx = get_current_tensor_id(shape_rep, num_tensors, - global_tile_n_offset * M, packed_N, M, offsets); - if (cur_group_idx != group_idx) { - group_idx = cur_group_idx; - a_global_amax_val = shared_storage.global_a_amax[group_idx]; - // Update group quantization parameters/scaling - global_encode_scale = a_global_amax_val > 0.0f - ? cutlass::minimum_with_nan_propagation{}( - (fp8_max * fp4_max) / a_global_amax_val, - cutlass::platform::numeric_limits::max()) - : 1.0f; - global_decode_scale = 1.0f / global_encode_scale; - if constexpr (kUseFastMath) { - global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + int group_start_offset = offsets[group_idx] / M; + int local_tile_n_idx = + (global_tile_n_offset - group_start_offset) / size<1>(epilogue_tiler); + Tensor tDgD_mn = gD_mn_view(_, _, _, scheduler.tile_m(), local_tile_n_idx); + + Tensor tDgSFD_mn = gSFD_mn(_, _, scheduler.tile_m(), local_tile_n_idx); + accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); + + auto Acc = bulk_tmem_epilogue(_, _, _, accumulator_pipe_consumer_state.index()); + Tensor tDtAcc = thr_t2r.partition_S(Acc); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDgD = thr_t2r.partition_D(tDgD_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + + Tensor tTR_rAcc = make_tensor( + shape(tDgD)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) 
+ Tensor tDrD = make_tensor(shape(tDgD)); + Tensor tTR_rAcc_frag = + recast>(coalesce(tTR_rAcc)); + Tensor tDrD_frag = recast>(coalesce(tDrD)); + + Tensor src = thr_r2g.retile_S(tDrD); + Tensor dst = thr_r2g.retile_D(tDgD); + + Tensor tDgSFD_view = make_tensor( + tDgSFD_mn.data(), make_layout(make_shape(shape(tDgSFD_mn), Int<1>{}, Int<1>{}), + make_stride(stride(tDgSFD_mn), Int<0>{}, Int<0>{}))); + Tensor tDgSFD = filter(thr_t2r.partition_D(tDgSFD_view)); + Tensor tDrSFD = make_tensor(shape(tDgSFD)); + + static int constexpr NumVecs = size(tDgD) / VectorSize; + Tensor tD_rRowSFD_frg = recast>(tDrSFD); + + // Compute amax and quantization scales for this tile + cutlass::maximum_absolute_value_reduction< + cutlass::Array, true> + amax_reduction; + cutlass::Array vec_maxs; + cutlass::Array pvscales; + // Copy from TMEM to registers + copy(tiled_t2r, tDtAcc, tTR_rAcc); + cutlass::arch::fence_view_async_tmem_load(); + accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); + ++accumulator_pipe_consumer_state; + + if constexpr (!kUseFastMath) { + // Downcast to BF16 for bit-wise compatibility with + // unfused kernels + auto convert_accum_to_bf16 = + cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = + cutlass::NumericArrayConverter{}; + tTR_rAcc_frag(_0{}) = + convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + tTR_rAcc_frag(_1{}) = + convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_1{}))); } - } - auto tQAgSFA_mn = tQAgSFA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); - auto tQAgQA_mn = tQAgQA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); - auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state); - mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); - copy(tiled_s2r, tQAsA(_, _, _, mainloop_pipe_consumer_state.index()), tQArA); - cutlass::arch::fence_view_async_shared(); - 
mainloop_pipeline.consumer_release(mainloop_pipe_consumer_state); - ++mainloop_pipe_consumer_state; - ++k_tile; + auto compute_frgs = reinterpret_cast *>( + tTR_rAcc_frag.data()); + auto output_frgs = reinterpret_cast *>(tDrD_frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); + } - // static int constexpr NumVecs = size(tQArA) / VectorSize; - cutlass::maximum_absolute_value_reduction, - true> - amax_reduction; - auto compute_frgs = reinterpret_cast *>(tQArA.data()); - auto output_frgs = - reinterpret_cast *>(raw_pointer_cast(tQArQA.data())); - Tensor amax = - make_tensor(prepend(take<1, rank(tQArA)>(tQArA.shape()), _1{})); - Tensor pvscales = make_tensor_like(amax); - transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; - if constexpr (kEnableStochasticRounding) { - const size_t rng_sequence = global_thread_idx + k_tile * 512 + - scheduler.get_linear_tile_idx() * K_TILE_MAX * 512 + - tiles_in_m * tiles_in_n * K_TILE_MAX * 512; - rng.init(rng_seed, rng_sequence, rng_offset); - } - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < size<1>(group_modes<1, rank(tQArA)>(tQArA)); v++) { - auto amax_view = group_modes<1, rank(amax)>(amax); - auto pvscales_view = group_modes<1, rank(pvscales)>(pvscales); - auto compute_frgs_up = - cutlass::NumericArrayConverter{}( - compute_frgs[v]); - amax_view(_0{}, v) = amax_reduction(ElementAccumulator(0), compute_frgs_up); if constexpr (kUseFastMath) { // Fast math: multiply with precomputed reciprocal - pvscales_view(_0{}, v) = cutlass::multiplies{}( - amax_view(_0{}, v), global_encode_scale_multiplier); + pvscales = cutlass::multiplies>{}( + vec_maxs, global_encode_scale_multiplier); } else { // Accurate math: perform division - pvscales_view(_0{}, v) = - cutlass::divides{}(amax_view(_0{}, v), fp4_max); - pvscales_view(_0{}, v) = cutlass::multiplies{}( - pvscales_view(_0{}, v), global_encode_scale); + pvscales = 
cutlass::divides>{}(vec_maxs, + fp4_max); + pvscales = cutlass::multiplies>{}( + pvscales, global_encode_scale); } - filter(tQArSFA)(v) = sfa_converter(pvscales_view(_0{}, v)); - auto qpvscale_ups = - cutlass::NumericConverter{}(filter(tQArSFA)(v)); + auto pvscales_cvted = + cutlass::NumericArrayConverter{}(pvscales); + + tD_rRowSFD_frg(_0{}) = pvscales_cvted; + auto qpvscale_ups = cutlass::NumericArrayConverter{}( + tD_rRowSFD_frg(_0{})); auto qpvscale_scaled = - cutlass::multiplies{}(qpvscale_ups, global_decode_scale); - ElementAccumulator acc_scales; + cutlass::multiplies>{}( + qpvscale_ups, global_decode_scale); + cutlass::Array acc_scales; if constexpr (kUseFastMath) { // Fast math: compute approximate reciprocal acc_scales = cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); } else { // Accurate math: compute reciprocal with division - acc_scales = cutlass::divides{}(1.0, qpvscale_scaled); + acc_scales = cutlass::divides>{}( + 1.0, qpvscale_scaled); } - auto acc_scale = cutlass::minimum_with_nan_propagation{}( - acc_scales, cutlass::platform::numeric_limits::max()); + + // Prepare stochastic rounding random state if enabled uint4 random_uint4 = uint4{0, 0, 0, 0}; + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + // "Prefetch" a stochastic rounding state for the first tile if constexpr (kEnableStochasticRounding) { - random_uint4 = rng.generate4(); - output_frgs[v] = StochasticNumericConverter( - cutlass::multiplies>{}( - compute_frgs_up, acc_scale), - *reinterpret_cast *>(&random_uint4)); - } else { - output_frgs[v] = - cutlass::NumericArrayConverter{}( - cutlass::multiplies>{}( - compute_frgs_up, acc_scale)); + const size_t rng_sequence = global_thread_idx + k_tile * 512 + + scheduler.get_linear_tile_idx() * K_TILE_MAX * 512; + rng.init(rng_seed, rng_sequence, rng_offset); + } + CUTLASS_PRAGMA_UNROLL + // Apply round/quantize to each fragment, with or without stochastic rounding + for (int v = 0; v < NumVecs; v++) { + auto 
acc_scale = cutlass::minimum_with_nan_propagation{}( + acc_scales[v], cutlass::platform::numeric_limits::max()); + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter( + cutlass::multiplies>{}( + compute_frgs[v], acc_scale), + *reinterpret_cast *>(&random_uint4)); + } else { + output_frgs[v] = + cutlass::NumericArrayConverter{}( + cutlass::multiplies>{}( + compute_frgs[v], acc_scale)); + } } + + // Write quantized FP4 tile and dequant scale to gmem + copy(tiled_r2g, src, dst); + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFD, tDgSFD); } - copy(tiled_r2g_QA, tQArQA, tQAgQA_mn); - copy(tiled_r2g_SFA, filter(tQArSFA), filter(tQAgSFA_mn)); + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } + } else if (is_epilogue_row_quant_warp) { + // Warp responsible for quantizing the input (before Hadamard transform) to FP4 for row-wise usage. + cutlass::arch::warpgroup_reg_alloc<136>(); + if constexpr (kEnableRowQuant) { + using S2RVectorType = uint128_t; + + int global_thread_idx = threadIdx.x; + int local_thread_idx = global_thread_idx % 256; + size_t rng_seed = 0; + size_t rng_offset = 0; + // g2s load all global_a_amax for all groups/tensors + CUTLASS_PRAGMA_NO_UNROLL + for (int g = local_thread_idx; g < num_tensors; g += NumEpilogueRowQuantThreadCount) { + shared_storage.global_a_amax[g] = __ldg(reinterpret_cast(amax_rowwise + g)); } - // scheduler.advance(); - scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); - ++sched_pipeline_consumer_state; - scheduler.update_work_tile_info(); - } while (scheduler.is_valid()); - } + // RNG for stochastic rounding + if constexpr (kEnableStochasticRounding) { + rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0; + rng_offset = rng_state != nullptr ? 
__ldg(rng_state + 1) : 0; + } + // Input/output tensors/partitions for row quant warp + Tensor mQA = + make_tensor(cute::subbyte_iterator(QA), make_layout(make_shape(M, packed_N), dQA)); + Tensor gQA_mn = local_tile(mQA, epilogue_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor mSFA = make_tensor(make_gmem_ptr(SFA), sfa_layout); + + Tensor gSFA_mn = local_tile(mSFA, epilogue_tiler, make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_N) + // Swizzled shared memory A tile, with layout + Tensor sA = as_position_independent_swizzle_tensor(group_modes<0, 2>( + coalesce(make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), + sAlayout)))); // (BLOCK_M, BLOCK_M,PIPE) + + // Set up layouts for partitioning – tile-by-warp, with vector granularity + using S2RWarpLayout = Layout>; + using WarpGroupLayout = Layout>; + using S2RThreadLayout = decltype(blocked_product(S2RWarpLayout{}, WarpGroupLayout{})); + using S2RValLayout = Layout, _1>>; + using S2RAtomA = Copy_Atom; + using R2GAtomQA = Copy_Atom; + using R2GAtomSFA = Copy_Atom; + auto tiled_s2r = make_tiled_copy(S2RAtomA{}, S2RThreadLayout{}, S2RValLayout{}); + auto tiled_r2g_QA = make_tiled_copy(R2GAtomQA{}, S2RThreadLayout{}, S2RValLayout{}); + auto tiled_r2g_SFA = make_tiled_copy(R2GAtomSFA{}, S2RThreadLayout{}, S2RValLayout{}); + + auto thr_s2r = tiled_s2r.get_slice(local_thread_idx); + auto thr_r2g_QA = tiled_r2g_QA.get_slice(local_thread_idx); + auto thr_r2g_SFA = tiled_r2g_SFA.get_slice(local_thread_idx); + Tensor tQAsA = thr_s2r.partition_S(sA); // (Copy, Copy_M, Copy_N, PIPE) + + // Allocate temporary register tensors for copying quantization => output + Tensor tQArA = make_tensor_like( + make_layout(tQAsA(_, _, _, _0{}).shape())); // (Copy, Copy_M, Copy_N) + Tensor tQAgQA = thr_r2g_QA.partition_S(gQA_mn); + Tensor tQArQA = make_tensor_like(tQAgQA(_, _, _, _0{}, _0{})); + + Tensor tQAgSFA = thr_r2g_SFA.partition_S(gSFA_mn); + Tensor tQArSFA = make_tensor_like(tQAgSFA(_, _, _, _0{}, _0{})); + 
+ // Will result in barrier_id=10 passed to bar.sync instr as cutlass adds 8 + // in order to go over the reserved named barrier count. + constexpr int row_quant_barrier_id = 2; + cutlass::arch::NamedBarrier::sync(NumEpilogueRowQuantThreadCount, row_quant_barrier_id); + + int group_idx = get_current_tensor_id( + shape_rep, num_tensors, (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M, + packed_N, M, offsets); + float a_global_amax_val = shared_storage.global_a_amax[group_idx]; + // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release} + static constexpr float fp4_max = 6.0f; + static constexpr float fp8_max = 448.0f; + static constexpr float fp4_max_inv = 1.0f / fp4_max; + float global_encode_scale = a_global_amax_val > 0.0f + ? cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / a_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + + float global_decode_scale = 1.0f / global_encode_scale; + float global_encode_scale_multiplier = 1.0f; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + auto sfa_converter = cutlass::NumericConverter{}; + do { + CUTLASS_PRAGMA_NO_UNROLL + for (int k_tile = 0; + k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) { + int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler); + + int cur_group_idx = get_current_tensor_id( + shape_rep, num_tensors, global_tile_n_offset * M, packed_N, M, offsets); + if (cur_group_idx != group_idx) { + group_idx = cur_group_idx; + a_global_amax_val = shared_storage.global_a_amax[group_idx]; + // Update group quantization parameters/scaling + global_encode_scale = a_global_amax_val > 0.0f + ? 
cutlass::minimum_with_nan_propagation{}( + (fp8_max * fp4_max) / a_global_amax_val, + cutlass::platform::numeric_limits::max()) + : 1.0f; + global_decode_scale = 1.0f / global_encode_scale; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + } - } else { - cutlass::arch::warpgroup_reg_dealloc<32>(); - } + auto tQAgSFA_mn = + tQAgSFA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto tQAgQA_mn = tQAgQA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile); + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state); + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + copy(tiled_s2r, tQAsA(_, _, _, mainloop_pipe_consumer_state.index()), tQArA); + cutlass::arch::fence_view_async_shared(); + mainloop_pipeline.consumer_release(mainloop_pipe_consumer_state); + ++mainloop_pipe_consumer_state; + ++k_tile; + + // static int constexpr NumVecs = size(tQArA) / VectorSize; + cutlass::maximum_absolute_value_reduction< + cutlass::Array, true> + amax_reduction; + auto compute_frgs = reinterpret_cast *>(tQArA.data()); + auto output_frgs = reinterpret_cast *>( + raw_pointer_cast(tQArQA.data())); + Tensor amax = + make_tensor(prepend(take<1, rank(tQArA)>(tQArA.shape()), _1{})); + Tensor pvscales = make_tensor_like(amax); + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + if constexpr (kEnableStochasticRounding) { + const size_t rng_sequence = global_thread_idx + k_tile * 512 + + scheduler.get_linear_tile_idx() * K_TILE_MAX * 512 + + tiles_in_m * tiles_in_n * K_TILE_MAX * 512; + rng.init(rng_seed, rng_sequence, rng_offset); + } + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < size<1>(group_modes<1, rank(tQArA)>(tQArA)); v++) { + auto amax_view = group_modes<1, rank(amax)>(amax); + auto pvscales_view = group_modes<1, rank(pvscales)>(pvscales); + auto compute_frgs_up = + cutlass::NumericArrayConverter{}( + compute_frgs[v]); 
+ amax_view(_0{}, v) = amax_reduction(ElementAccumulator(0), compute_frgs_up); + if constexpr (kUseFastMath) { + // Fast math: multiply with precomputed reciprocal + pvscales_view(_0{}, v) = cutlass::multiplies{}( + amax_view(_0{}, v), global_encode_scale_multiplier); + } else { + // Accurate math: perform division + pvscales_view(_0{}, v) = + cutlass::divides{}(amax_view(_0{}, v), fp4_max); + pvscales_view(_0{}, v) = cutlass::multiplies{}( + pvscales_view(_0{}, v), global_encode_scale); + } + filter(tQArSFA)(v) = sfa_converter(pvscales_view(_0{}, v)); + auto qpvscale_ups = + cutlass::NumericConverter{}(filter(tQArSFA)(v)); + auto qpvscale_scaled = + cutlass::multiplies{}(qpvscale_ups, global_decode_scale); + ElementAccumulator acc_scales; + if constexpr (kUseFastMath) { + // Fast math: compute approximate reciprocal + acc_scales = cutlass::reciprocal_approximate_ftz{}( + qpvscale_scaled); + } else { + // Accurate math: compute reciprocal with division + acc_scales = cutlass::divides{}(1.0, qpvscale_scaled); + } + auto acc_scale = cutlass::minimum_with_nan_propagation{}( + acc_scales, cutlass::platform::numeric_limits::max()); + uint4 random_uint4 = uint4{0, 0, 0, 0}; + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter( + cutlass::multiplies>{}( + compute_frgs_up, acc_scale), + *reinterpret_cast *>(&random_uint4)); + } else { + output_frgs[v] = + cutlass::NumericArrayConverter{}( + cutlass::multiplies>{}( + compute_frgs_up, acc_scale)); + } + } + copy(tiled_r2g_QA, tQArQA, tQAgQA_mn); + copy(tiled_r2g_SFA, filter(tQArSFA), filter(tQAgSFA_mn)); + } + // scheduler.advance(); + scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state); + ++sched_pipeline_consumer_state; + scheduler.update_work_tile_info(); + } while (scheduler.is_valid()); + } + + } else { + cutlass::arch::warpgroup_reg_dealloc<32>(); + } } } // NOLINT(readability/fn_size) diff --git 
a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu index b1ed7ce6fd..2181564c09 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu @@ -179,528 +179,532 @@ __global__ static void group_rht_gemm_device( "Try recompiling with sm_100a or similar."); return; } else { - using X = Underscore; - // static constexpr bool kApplyStochasticRounding = true; - using ElementAccumulator = float; - static constexpr int K_PIPE_MAX = size<3>(ASmemLayout{}); - using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>; - static constexpr uint32_t kTmaTransactionBytes = cutlass::bits_to_bytes( - size(AtomThrShapeMNK{}) * cosize(take<0, 3>(ASmemLayout{})) * cute::sizeof_bits_v); - - static constexpr int kTmaRhtTensorTransactionBytes = - cutlass::bits_to_bytes(16 * 16 * cute::sizeof_bits_v); - static constexpr int AccumulatorPipelineStageCount = 16; - - static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{}); - using MainloopPipeline = - cutlass::PipelineTmaUmmaAsync, AtomThrShapeMNK>; - using MainloopPipelineState = typename MainloopPipeline::PipelineState; - - using TmemAllocator = cute::TMEM::Allocator1Sm; - static constexpr int VectorSize = 16; - const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; - const size_t rng_offset = rng_state != nullptr ? 
rng_state[1] : 0; - // Preconditions - CUTE_STATIC_ASSERT(is_static::value); - CUTE_STATIC_ASSERT(is_static::value); - CUTE_STATIC_ASSERT(is_static::value); - - // Represent the full tensors - Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, N)); - Tensor mB = tma_load_b.get_tma_tensor(make_shape(16, 16)); - - using TensorC = decltype(make_tensor(subbyte_iterator(recast_ptr(nullptr)), // engine - make_shape(int{}, int{}), // (M, N_i) - Stride2D{} // stride (dM, dN) - )); - - using TensorSFC = decltype(make_tensor( - make_gmem_ptr(recast_ptr(nullptr)), - make_layout(make_shape(int{}, // M - make_shape(make_shape(Int<16>{}, _4{}), // (16, 4) - int{}) // n_tiles = split / 64 - ), - make_stride(int{}, // dM = (split / 16) - make_stride(make_stride(_0{}, _1{}), // inner (16,4) layout - _4{}) // tiles stride - )))); - - auto cluster_shape = Shape<_1, _1, _1>{}; - - // Get the appropriate blocks for this Cluster - dim3 cluster_coord_in_grid = cluster_id_in_grid(); - - // Total number of k-tiles - const int K_TILE_MAX = min(N, K) / 64; - uint32_t tiles_in_m = (M + size<0>(cluster_tile) - 1) / size<0>(cluster_tile); - uint32_t tiles_in_n = (N + 64 - 1) / 64; - uint32_t linear_tile_idx = blockIdx.x; - uint32_t tile_idx_m = linear_tile_idx % tiles_in_m; - uint32_t tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; - - auto mainloop_tiler = Shape<_128, _16, _64>{}; - auto epilogue_tiler = Shape<_128, _64, _64>{}; - Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); - Tensor gB_nk = - local_tile(mB, cluster_tile, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,k) - // Tensor gC_mn = local_tile(mC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) - - using TensorGC = decltype(local_tile(std::declval(), decltype(epilogue_tiler){}, - make_coord(_, _, _), Step<_1, _1, X>{})); - - using TensorGSFC = decltype(local_tile(std::declval(), decltype(epilogue_tiler){}, + using X = Underscore; + // static constexpr bool 
kApplyStochasticRounding = true; + using ElementAccumulator = float; + static constexpr int K_PIPE_MAX = size<3>(ASmemLayout{}); + using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>; + static constexpr uint32_t kTmaTransactionBytes = cutlass::bits_to_bytes( + size(AtomThrShapeMNK{}) * cosize(take<0, 3>(ASmemLayout{})) * cute::sizeof_bits_v); + + static constexpr int kTmaRhtTensorTransactionBytes = + cutlass::bits_to_bytes(16 * 16 * cute::sizeof_bits_v); + static constexpr int AccumulatorPipelineStageCount = 16; + + static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{}); + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync, AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + static constexpr int VectorSize = 16; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + + // Represent the full tensors + Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, N)); + Tensor mB = tma_load_b.get_tma_tensor(make_shape(16, 16)); + + using TensorC = decltype(make_tensor(subbyte_iterator(recast_ptr(nullptr)), // engine + make_shape(int{}, int{}), // (M, N_i) + Stride2D{} // stride (dM, dN) + )); + + using TensorSFC = decltype(make_tensor( + make_gmem_ptr(recast_ptr(nullptr)), + make_layout(make_shape(int{}, // M + make_shape(make_shape(Int<16>{}, _4{}), // (16, 4) + int{}) // n_tiles = split / 64 + ), + make_stride(int{}, // dM = (split / 16) + make_stride(make_stride(_0{}, _1{}), // inner (16,4) layout + _4{}) // tiles stride + )))); + + auto cluster_shape = Shape<_1, _1, _1>{}; + + // Get the appropriate blocks for this Cluster + dim3 cluster_coord_in_grid = cluster_id_in_grid(); + + // Total number of k-tiles + const int 
K_TILE_MAX = min(N, K) / 64; + uint32_t tiles_in_m = (M + size<0>(cluster_tile) - 1) / size<0>(cluster_tile); + uint32_t tiles_in_n = (N + 64 - 1) / 64; + uint32_t linear_tile_idx = blockIdx.x; + uint32_t tile_idx_m = linear_tile_idx % tiles_in_m; + uint32_t tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + + auto mainloop_tiler = Shape<_128, _16, _64>{}; + auto epilogue_tiler = Shape<_128, _64, _64>{}; + Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor gB_nk = + local_tile(mB, cluster_tile, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,k) + // Tensor gC_mn = local_tile(mC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + using TensorGC = decltype(local_tile(std::declval(), decltype(epilogue_tiler){}, make_coord(_, _, _), Step<_1, _1, X>{})); - // Allocate SMEM - extern __shared__ char shared_memory[]; - using SharedStorage = SharedStorage; - SharedStorage &shared_storage = *reinterpret_cast(shared_memory); - Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), - sAlayout); // (MMA,MMA_M,MMA_N,PIPE) - Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), - sBlayout); // (MMA,MMA_N,MMA_K,PIPE) - - // - // MMA: Define C accumulators and A/B partitioning - // - - int block_rank_in_cluster = cute::block_rank_in_cluster(); - ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx - Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) - - auto mma_epilogue = make_tiled_mma( - SM100_MMA_F16BF16_SS{}, - Layout>{}); - ThrMMA thr_mma_epilogue = mma_epilogue.get_slice(block_rank_in_cluster); - - using TiledMmaEpilogue = decltype(mma_epilogue); - Tensor tCgA = thr_mma.partition_A(gA_mk); - // Allocate "fragments" -- these are actually umma smem descriptors - Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) - - auto acc_shape_mma 
= partition_shape_C(TiledMMA{}, take<0, 2>(ClusterTileShape{})); - auto acc_shape_epilogue = partition_shape_C(TiledMmaEpilogue{}, take<0, 2>(epilogue_tiler)); - - auto bulk_tmem_mma = - TiledMMA::make_fragment_C(append(acc_shape_mma, Int{})); - - auto bulk_tmem_epilogue = TiledMmaEpilogue::make_fragment_C( - append(acc_shape_epilogue, Int{})); - - TmemAllocator tmem_allocator{}; - cutlass::arch::NamedBarrier tmem_allocation_result_barrier( - 32 + 128, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); - - Layout cta_layout_mnk = make_layout(cluster_shape); - Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); - auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); - - auto [tAgA, tAsA] = - tma_partition(tma_load_a, get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), - group_modes<0, 3>(tCsA), group_modes<0, 3>(tCgA)); - - auto [tBgB, tBsB] = - tma_partition(tma_load_b, get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), - group_modes<0, 3>(tCsB), group_modes<0, 3>(tCgB)); - - uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); - uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); - - int warp_idx = cutlass::canonical_warp_idx_sync(); - - bool is_mma_warp = (warp_idx == 0); - bool is_dma_warp = (warp_idx == 1); - bool is_epilogue_warp = (warp_idx >= 4 && warp_idx <= 7); - - // if (is_epilogue_warp && elect_one_sync()) { - // // prefetch to make the global amax in cache - // for (size_t i = 0; i < kernel_args.num_tensors; ++i) { - // cute::prefetch(raw_pointer_cast(kernel_args.global_amax_list[i])); - // } - // } - - typename MainloopPipeline::Params mainloop_pipeline_params; - if (is_dma_warp) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; - } - if (is_mma_warp) { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; - } - 
mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp; - mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes; - mainloop_pipeline_params.initializing_warp = 0; - MainloopPipeline mainloop_pipeline(shared_storage.mainloop, mainloop_pipeline_params, - cluster_shape, cute::true_type{}, // Perform barrier init - cute::true_type{}); // Delay mask calculation - - MainloopPipelineState mainloop_pipe_consumer_state; - MainloopPipelineState mainloop_pipe_producer_state = - cutlass::make_producer_start_state(); - - using AccumulatorPipeline = - cutlass::PipelineUmmaAsync; - using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; - - AccumulatorPipelineState accumulator_pipe_consumer_state; - AccumulatorPipelineState accumulator_pipe_producer_state = - cutlass::make_producer_start_state(); - - typename AccumulatorPipeline::Params accumulator_pipeline_params; - if (is_mma_warp) { - accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; - } - if (is_epilogue_warp) { - accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; - } - // Only one producer thread arrives on this barrier. 
- accumulator_pipeline_params.producer_arv_count = 1; - accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * 128; - accumulator_pipeline_params.initializing_warp = 1; - AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, accumulator_pipeline_params, - cluster_shape, - cute::true_type{}, // Perform barrier init - cute::true_type{}); // Delay mask calculation - - if (warp_idx == 2 && elect_one_sync()) { - cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); - } - __syncthreads(); - using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; - - if (is_dma_warp) { - if (elect_one_sync()) { - cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], - kTmaRhtTensorTransactionBytes); - copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_, 0, 0), - tBsB(_, 0)); + using TensorGSFC = decltype(local_tile(std::declval(), decltype(epilogue_tiler){}, + make_coord(_, _, _), Step<_1, _1, X>{})); + + // Allocate SMEM + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage &shared_storage = *reinterpret_cast(shared_memory); + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), + sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), + sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + // + // MMA: Define C accumulators and A/B partitioning + // + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) + + auto mma_epilogue = make_tiled_mma(SM100_MMA_F16BF16_SS{}, + Layout>{}); + ThrMMA thr_mma_epilogue = mma_epilogue.get_slice(block_rank_in_cluster); + + using TiledMmaEpilogue = decltype(mma_epilogue); + Tensor tCgA = thr_mma.partition_A(gA_mk); + // Allocate "fragments" -- these are actually umma smem descriptors 
+ Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) + + auto acc_shape_mma = partition_shape_C(TiledMMA{}, take<0, 2>(ClusterTileShape{})); + auto acc_shape_epilogue = partition_shape_C(TiledMmaEpilogue{}, take<0, 2>(epilogue_tiler)); + + auto bulk_tmem_mma = + TiledMMA::make_fragment_C(append(acc_shape_mma, Int{})); + + auto bulk_tmem_epilogue = TiledMmaEpilogue::make_fragment_C( + append(acc_shape_epilogue, Int{})); + + TmemAllocator tmem_allocator{}; + cutlass::arch::NamedBarrier tmem_allocation_result_barrier( + 32 + 128, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + + Layout cta_layout_mnk = make_layout(cluster_shape); + Layout cta_layout_vmnk = + tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + auto [tAgA, tAsA] = + tma_partition(tma_load_a, get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0, 3>(tCsA), group_modes<0, 3>(tCgA)); + + auto [tBgB, tBsB] = + tma_partition(tma_load_b, get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0, 3>(tCsB), group_modes<0, 3>(tCgB)); + + uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + + bool is_mma_warp = (warp_idx == 0); + bool is_dma_warp = (warp_idx == 1); + bool is_epilogue_warp = (warp_idx >= 4 && warp_idx <= 7); + + // if (is_epilogue_warp && elect_one_sync()) { + // // prefetch to make the global amax in cache + // for (size_t i = 0; i < kernel_args.num_tensors; ++i) { + // cute::prefetch(raw_pointer_cast(kernel_args.global_amax_list[i])); + // } + // } + + typename MainloopPipeline::Params mainloop_pipeline_params; + if (is_dma_warp) { + 
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (is_mma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp; + mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + MainloopPipeline mainloop_pipeline(shared_storage.mainloop, mainloop_pipeline_params, + cluster_shape, cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = + cutlass::make_producer_start_state(); + + using AccumulatorPipeline = + cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = + cutlass::make_producer_start_state(); + + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (is_mma_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; } + if (is_epilogue_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. 
+ accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * 128; + accumulator_pipeline_params.initializing_warp = 1; + AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, + accumulator_pipeline_params, cluster_shape, + cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + if (warp_idx == 2 && elect_one_sync()) { + cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); + } + __syncthreads(); + using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; + + if (is_dma_warp) { + if (elect_one_sync()) { + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], + kTmaRhtTensorTransactionBytes); + copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_, 0, 0), + tBsB(_, 0)); + } - do { - bool is_first_wave = linear_tile_idx == blockIdx.x; - uint32_t skip_wait = is_first_wave; - auto tAgA_mk = tAgA(_, tile_idx_m, _); - int k_tile = 0; - auto barrier_token = - mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait); - - CUTE_NO_UNROLL - while (k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n) { - int k_tile_idx_n = tile_idx_n + k_tile; - ++k_tile; - skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount); - mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); - using BarrierType = typename MainloopPipeline::ProducerBarrierType; - BarrierType *tma_barrier = - mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); - int write_stage = mainloop_pipe_producer_state.index(); - ++mainloop_pipe_producer_state; - barrier_token = + do { + bool is_first_wave = linear_tile_idx == blockIdx.x; + uint32_t skip_wait = is_first_wave; + auto tAgA_mk = tAgA(_, tile_idx_m, _); + int k_tile = 0; + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait); - if 
(cute::elect_one_sync()) { - copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_, k_tile_idx_n), - tAsA(_, write_stage)); + + CUTE_NO_UNROLL + while (k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n) { + int k_tile_idx_n = tile_idx_n + k_tile; + ++k_tile; + skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount); + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType *tma_barrier = + mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = + mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait); + if (cute::elect_one_sync()) { + copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_, k_tile_idx_n), + tAsA(_, write_stage)); + } } - } - linear_tile_idx += gridDim.x; - tile_idx_m = linear_tile_idx % tiles_in_m; - tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; - } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); - mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); - } else if (is_mma_warp) { - mma.accumulate_ = UMMA::ScaleOut::Zero; - - tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); - __syncwarp(); - tmem_allocation_result_barrier.arrive(); - uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; - bulk_tmem_mma.data() = tmem_base_ptr; - - cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); - do { - uint32_t skip_wait = K_TILE_MAX <= 0; - auto barrier_token = - mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); - CUTE_NO_UNROLL - for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n;) { - mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); - int read_stage = mainloop_pipe_consumer_state.index(); - 
auto tCrA_mk = tCrA(_, _, _, read_stage); - auto tCrB_nk = tCrB(_, _, 0, 0); - CUTE_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA) / 4; ++k_block) { - accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } else if (is_mma_warp) { + mma.accumulate_ = UMMA::ScaleOut::Zero; + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, + &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_mma.data() = tmem_base_ptr; + + cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); + do { + uint32_t skip_wait = K_TILE_MAX <= 0; + auto barrier_token = + mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + CUTE_NO_UNROLL + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n;) { + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + int read_stage = mainloop_pipe_consumer_state.index(); + auto tCrA_mk = tCrA(_, _, _, read_stage); + auto tCrB_nk = tCrB(_, _, 0, 0); CUTE_UNROLL - for (int i = 0; i < 4; i++) { - auto accumulators = - bulk_tmem_mma(_, _, _, accumulator_pipe_producer_state.index() * 4 + i); - gemm(mma, tCrA_mk(_, _, k_block * 4 + i), tCrB_nk, accumulators); + for (int k_block = 0; k_block < size<2>(tCrA) / 4; ++k_block) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + CUTE_UNROLL + for (int i = 0; i < 4; i++) { + auto accumulators = + bulk_tmem_mma(_, _, _, accumulator_pipe_producer_state.index() * 4 + i); + gemm(mma, tCrA_mk(_, _, k_block * 4 + i), tCrB_nk, accumulators); + } + + 
accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + ++accumulator_pipe_producer_state; } - - accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); - ++accumulator_pipe_producer_state; + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + ++mainloop_pipe_consumer_state; + ++k_tile; + skip_wait = k_tile >= K_TILE_MAX; + barrier_token = + mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); } - auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; - ++mainloop_pipe_consumer_state; - ++k_tile; - skip_wait = k_tile >= K_TILE_MAX; - barrier_token = - mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); - mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); - } - linear_tile_idx += gridDim.x; - tile_idx_m = linear_tile_idx % tiles_in_m; - tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; - } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); - tmem_allocator.release_allocation_lock(); - accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); - tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); - } else if (is_epilogue_warp) { - static constexpr int FragmentSize = 256 / sizeof_bits_v; - - tmem_allocation_result_barrier.arrive_and_wait(); - uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; - bulk_tmem_epilogue.data() = tmem_base_ptr; - int thread_idx = threadIdx.x % 128; - - auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_, _, _, _0{})); - auto tiled_r2g = - make_tiled_copy_D(Copy_Atom{}, tiled_t2r); - auto thr_t2r = tiled_t2r.get_slice(thread_idx); - auto thr_r2g = tiled_r2g.get_slice(thread_idx); - - // NVFP4 non-E8 recipe constants and global scales - static constexpr float fp4_max = 6.0f; - static constexpr float fp4_max_inv = 1.0f / fp4_max; - - // get global amax pointer - 
int tensor_id = GetTensorId(&kernel_args, tile_idx_n * 64); - float *global_amax_ptr = GetGlobalAmaxPtrByTensorId(&kernel_args, tensor_id); - - TC *cur_output_colwise_ptr = reinterpret_cast(kernel_args.output_colwise_list[tensor_id]); - TSFC *cur_output_colwise_scale_inv_ptr = - reinterpret_cast(kernel_args.output_colwise_scale_inv_list[tensor_id]); - int cur_output_colwise_n = kernel_args.split_sections[tensor_id]; - - TensorC cur_mC = - cute::make_tensor(cute::subbyte_iterator(cur_output_colwise_ptr), - cute::make_shape(static_cast(M), cur_output_colwise_n), // (M, N_i) - kernel_args.output_stride2d_list[tensor_id]); - - auto cur_sfc_shape = - make_shape(M, make_shape(make_shape(Int<16>{}, _4{}), cur_output_colwise_n / 64)); - - auto cur_sfc_stride = - make_stride(cur_output_colwise_n / 16, make_stride(make_stride(_0{}, _1{}), _4{})); - - TensorSFC cur_mSFC = cute::make_tensor(make_gmem_ptr(cur_output_colwise_scale_inv_ptr), - make_layout(cur_sfc_shape, cur_sfc_stride)); - - TensorGC cur_gC_mn = - local_tile(cur_mC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{} // (BLK_M, BLK_N) - ); - - TensorGSFC cur_gSFC_mn = local_tile( - cur_mSFC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{} // (BLK_M, BLK_N-like) - ); - - Tensor tCgC = thr_mma_epilogue.partition_C(cur_gC_mn); - - float global_amax_val = *global_amax_ptr; - float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); - - // Scaling factor for fast math path - float global_encode_scale_multiplier = 1.0f; - if constexpr (kUseFastMath) { - global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; - } + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + tmem_allocator.release_allocation_lock(); + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + tmem_allocator.free(tmem_base_ptr, 
TmemAllocator::Sm100TmemCapacityColumns); + } else if (is_epilogue_warp) { + static constexpr int FragmentSize = 256 / sizeof_bits_v; + + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_epilogue.data() = tmem_base_ptr; + int thread_idx = threadIdx.x % 128; + + auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_, _, _, _0{})); + auto tiled_r2g = + make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_t2r = tiled_t2r.get_slice(thread_idx); + auto thr_r2g = tiled_r2g.get_slice(thread_idx); + + // NVFP4 non-E8 recipe constants and global scales + static constexpr float fp4_max = 6.0f; + static constexpr float fp4_max_inv = 1.0f / fp4_max; + + // get global amax pointer + int tensor_id = GetTensorId(&kernel_args, tile_idx_n * 64); + float *global_amax_ptr = GetGlobalAmaxPtrByTensorId(&kernel_args, tensor_id); + + TC *cur_output_colwise_ptr = + reinterpret_cast(kernel_args.output_colwise_list[tensor_id]); + TSFC *cur_output_colwise_scale_inv_ptr = + reinterpret_cast(kernel_args.output_colwise_scale_inv_list[tensor_id]); + int cur_output_colwise_n = kernel_args.split_sections[tensor_id]; + + TensorC cur_mC = cute::make_tensor( + cute::subbyte_iterator(cur_output_colwise_ptr), + cute::make_shape(static_cast(M), cur_output_colwise_n), // (M, N_i) + kernel_args.output_stride2d_list[tensor_id]); + + auto cur_sfc_shape = + make_shape(M, make_shape(make_shape(Int<16>{}, _4{}), cur_output_colwise_n / 64)); + + auto cur_sfc_stride = + make_stride(cur_output_colwise_n / 16, make_stride(make_stride(_0{}, _1{}), _4{})); + + TensorSFC cur_mSFC = cute::make_tensor(make_gmem_ptr(cur_output_colwise_scale_inv_ptr), + make_layout(cur_sfc_shape, cur_sfc_stride)); + + TensorGC cur_gC_mn = local_tile( + cur_mC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{} // (BLK_M, BLK_N) + ); + + TensorGSFC cur_gSFC_mn = local_tile( + cur_mSFC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{} // (BLK_M, 
BLK_N-like) + ); + + Tensor tCgC = thr_mma_epilogue.partition_C(cur_gC_mn); + + float global_amax_val = *global_amax_ptr; + float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); + + // Scaling factor for fast math path + float global_encode_scale_multiplier = 1.0f; + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } - float global_decode_scale = 1.0f / global_encode_scale; - - auto sfd_converter = cutlass::NumericConverter{}; - - do { - for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) { - // get the starting index of current k-tile in global tensor, to query the correct global amax - int cur_k_tile_global_elem_idx = (tile_idx_n + k_tile) * 64; - int new_tensor_id = GetTensorId(&kernel_args, cur_k_tile_global_elem_idx); - // float* new_global_amax_ptr = GetGlobalAmaxPtr(&kernel_args, cur_k_tile_global_elem_idx); - global_amax_ptr = GetGlobalAmaxPtrByTensorId(&kernel_args, new_tensor_id); - // update the scaling factors when it's no longer the same amax pointer - // TODO(zhongbo): the math operations are very expensive - // since the kernel is persistent, we can have a cache for all the possible scaling factors - if (tensor_id != new_tensor_id) { - global_amax_val = *global_amax_ptr; - global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); - if constexpr (kUseFastMath) { - global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + float global_decode_scale = 1.0f / global_encode_scale; + + auto sfd_converter = cutlass::NumericConverter{}; + + do { + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) { + // get the starting index of current k-tile in global tensor, to query the correct global amax + int cur_k_tile_global_elem_idx = (tile_idx_n + k_tile) * 64; + int new_tensor_id = GetTensorId(&kernel_args, cur_k_tile_global_elem_idx); + // float* new_global_amax_ptr = 
GetGlobalAmaxPtr(&kernel_args, cur_k_tile_global_elem_idx); + global_amax_ptr = GetGlobalAmaxPtrByTensorId(&kernel_args, new_tensor_id); + // update the scaling factors when it's no longer the same amax pointer + // TODO(zhongbo): the math operations are very expensive + // since the kernel is persistent, we can have a cache for all the possible scaling factors + if (tensor_id != new_tensor_id) { + global_amax_val = *global_amax_ptr; + global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); + if constexpr (kUseFastMath) { + global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; + } + global_decode_scale = 1.0f / global_encode_scale; + tensor_id = new_tensor_id; + // went through the cute operations to update the local tensors + cur_output_colwise_ptr = + reinterpret_cast(kernel_args.output_colwise_list[tensor_id]); + cur_output_colwise_scale_inv_ptr = + reinterpret_cast(kernel_args.output_colwise_scale_inv_list[tensor_id]); + cur_output_colwise_n = kernel_args.split_sections[tensor_id]; + + cur_mC = cute::make_tensor( + cute::subbyte_iterator(cur_output_colwise_ptr), + cute::make_shape(static_cast(M), cur_output_colwise_n), // (M, N_i) + kernel_args.output_stride2d_list[tensor_id]); + + cur_sfc_shape = + make_shape(M, make_shape(make_shape(Int<16>{}, _4{}), cur_output_colwise_n / 64)); + + cur_sfc_stride = + make_stride(cur_output_colwise_n / 16, make_stride(make_stride(_0{}, _1{}), _4{})); + + cur_mSFC = cute::make_tensor(make_gmem_ptr(cur_output_colwise_scale_inv_ptr), + make_layout(cur_sfc_shape, cur_sfc_stride)); + + cur_gC_mn = local_tile( + cur_mC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{} // (BLK_M, BLK_N) + ); + + cur_gSFC_mn = local_tile(cur_mSFC, epilogue_tiler, make_coord(_, _, _), + Step<_1, _1, X>{} // (BLK_M, BLK_N-like) + ); + + tCgC = thr_mma_epilogue.partition_C(cur_gC_mn); + } + // maybe udpated to the new tensor id + int tensor_start_elem = kernel_args.split_sections_range[tensor_id]; + int local_tile_idx_n = 
(cur_k_tile_global_elem_idx - tensor_start_elem) / 64; + + Tensor tCgC_mn = tCgC(_, _, _, tile_idx_m, local_tile_idx_n); + Tensor tCgSFC_mn = cur_gSFC_mn(_, _, tile_idx_m, local_tile_idx_n); + + accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); + + auto tCtC = bulk_tmem_epilogue(_, _, _, accumulator_pipe_consumer_state.index()); + Tensor tDtC = thr_t2r.partition_S(tCtC); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDgC = thr_t2r.partition_D(tCgC_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + + Tensor tTR_rAcc = + make_tensor(shape(tDgC)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDrC = make_tensor(shape(tDgC)); + Tensor tTR_rAcc_frag = + recast>(coalesce(tTR_rAcc)); + Tensor tDrC_frag = recast>(coalesce(tDrC)); + + Tensor src = thr_r2g.retile_S(tDrC); + Tensor dst = thr_r2g.retile_D(tDgC); + + Tensor tCgSFC = make_tensor( + tCgSFC_mn.data(), make_layout(make_shape(shape(tCgSFC_mn), Int<1>{}, Int<1>{}), + make_stride(stride(tCgSFC_mn), Int<0>{}, Int<0>{}))); + + Tensor tDgSFC = filter(thr_t2r.partition_D(tCgSFC)); + Tensor tDrSFC = make_tensor(shape(tDgSFC)); + + static constexpr int NumVecs = size(tDgC) / VectorSize; + Tensor tC_rRowSFD_frg = recast>(tDrSFC); + + cutlass::maximum_absolute_value_reduction, + true> + amax_reduction; + cutlass::Array vec_maxs; + cutlass::Array pvscales; + // TMEM_LOAD + copy(tiled_t2r, tDtC, tTR_rAcc); + cutlass::arch::fence_view_async_tmem_load(); + + accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); + + ++accumulator_pipe_consumer_state; + + if constexpr (!kUseFastMath) { + // Downcast to BF16 for bit-wise compatibility with unfused + // kernels + auto convert_accum_to_bf16 = + cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = + cutlass::NumericArrayConverter{}; + tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); } - global_decode_scale = 1.0f / global_encode_scale; - tensor_id = new_tensor_id; - // went through the cute 
operations to update the local tensors - cur_output_colwise_ptr = - reinterpret_cast(kernel_args.output_colwise_list[tensor_id]); - cur_output_colwise_scale_inv_ptr = - reinterpret_cast(kernel_args.output_colwise_scale_inv_list[tensor_id]); - cur_output_colwise_n = kernel_args.split_sections[tensor_id]; - - cur_mC = cute::make_tensor( - cute::subbyte_iterator(cur_output_colwise_ptr), - cute::make_shape(static_cast(M), cur_output_colwise_n), // (M, N_i) - kernel_args.output_stride2d_list[tensor_id]); - - cur_sfc_shape = - make_shape(M, make_shape(make_shape(Int<16>{}, _4{}), cur_output_colwise_n / 64)); - - cur_sfc_stride = - make_stride(cur_output_colwise_n / 16, make_stride(make_stride(_0{}, _1{}), _4{})); - - cur_mSFC = cute::make_tensor(make_gmem_ptr(cur_output_colwise_scale_inv_ptr), - make_layout(cur_sfc_shape, cur_sfc_stride)); - - cur_gC_mn = local_tile( - cur_mC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{} // (BLK_M, BLK_N) - ); - - cur_gSFC_mn = local_tile(cur_mSFC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{} - // (BLK_M, BLK_N-like) - ); - - tCgC = thr_mma_epilogue.partition_C(cur_gC_mn); - } - // maybe udpated to the new tensor id - int tensor_start_elem = kernel_args.split_sections_range[tensor_id]; - int local_tile_idx_n = (cur_k_tile_global_elem_idx - tensor_start_elem) / 64; - - Tensor tCgC_mn = tCgC(_, _, _, tile_idx_m, local_tile_idx_n); - Tensor tCgSFC_mn = cur_gSFC_mn(_, _, tile_idx_m, local_tile_idx_n); - - accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); - - auto tCtC = bulk_tmem_epilogue(_, _, _, accumulator_pipe_consumer_state.index()); - Tensor tDtC = thr_t2r.partition_S(tCtC); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) - Tensor tDgC = thr_t2r.partition_D(tCgC_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) - - Tensor tTR_rAcc = - make_tensor(shape(tDgC)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) - Tensor tDrC = make_tensor(shape(tDgC)); - Tensor tTR_rAcc_frag = - recast>(coalesce(tTR_rAcc)); - Tensor 
tDrC_frag = recast>(coalesce(tDrC)); - - Tensor src = thr_r2g.retile_S(tDrC); - Tensor dst = thr_r2g.retile_D(tDgC); - - Tensor tCgSFC = make_tensor( - tCgSFC_mn.data(), make_layout(make_shape(shape(tCgSFC_mn), Int<1>{}, Int<1>{}), - make_stride(stride(tCgSFC_mn), Int<0>{}, Int<0>{}))); - - Tensor tDgSFC = filter(thr_t2r.partition_D(tCgSFC)); - Tensor tDrSFC = make_tensor(shape(tDgSFC)); - - static constexpr int NumVecs = size(tDgC) / VectorSize; - Tensor tC_rRowSFD_frg = recast>(tDrSFC); - - cutlass::maximum_absolute_value_reduction, - true> - amax_reduction; - cutlass::Array vec_maxs; - cutlass::Array pvscales; - // TMEM_LOAD - copy(tiled_t2r, tDtC, tTR_rAcc); - cutlass::arch::fence_view_async_tmem_load(); - - accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); - - ++accumulator_pipe_consumer_state; - - if constexpr (!kUseFastMath) { - // Downcast to BF16 for bit-wise compatibility with unfused - // kernels - auto convert_accum_to_bf16 = - cutlass::NumericArrayConverter{}; - auto convert_bf16_to_accum = - cutlass::NumericArrayConverter{}; - tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); - } - - auto compute_frgs = reinterpret_cast *>( - tTR_rAcc_frag.data()); - auto output_frgs = reinterpret_cast *>(tDrC_frag.data()); - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < NumVecs; v++) { - vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); - } - if constexpr (kUseFastMath) { - // Fast math: multiply with precomputed reciprocal - pvscales = cutlass::multiplies>{}( - vec_maxs, global_encode_scale_multiplier); - } else { - // Accurate math: perform division - pvscales = - cutlass::divides>{}(vec_maxs, fp4_max); - pvscales = cutlass::multiplies>{}( - pvscales, global_encode_scale); - } - auto pvscales_cvted = - cutlass::NumericArrayConverter{}(pvscales); - - tC_rRowSFD_frg(_0{}) = pvscales_cvted; - auto qpvscale_ups = cutlass::NumericArrayConverter{}( - tC_rRowSFD_frg(_0{})); - auto 
qpvscale_scaled = cutlass::multiplies>{}( - qpvscale_ups, global_decode_scale); - cutlass::Array acc_scales; - if constexpr (kUseFastMath) { - // Fast math: compute approximate reciprocal - acc_scales = - cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); - } else { - // Accurate math: compute reciprocal with division - acc_scales = - cutlass::divides>{}(1.0, qpvscale_scaled); - } + auto compute_frgs = reinterpret_cast *>( + tTR_rAcc_frag.data()); + auto output_frgs = reinterpret_cast *>(tDrC_frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); + } - // Initialize RNG for tile - const size_t rng_sequence = thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256; - - transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; - rng.init(rng_seed, rng_sequence, rng_offset); - uint4 random_uint4 = uint4{0, 0, 0, 0}; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < NumVecs; v++) { - auto acc_scale = cutlass::minimum_with_nan_propagation{}( - acc_scales[v], cutlass::platform::numeric_limits::max()); - // auto acc_scale = acc_scales[v]; - if constexpr (kEnableStochasticRounding) { - random_uint4 = rng.generate4(); - output_frgs[v] = StochasticNumericConverter( - cutlass::multiplies>{}( - compute_frgs[v], acc_scale), - reinterpret_cast *>(&random_uint4)); + if constexpr (kUseFastMath) { + // Fast math: multiply with precomputed reciprocal + pvscales = cutlass::multiplies>{}( + vec_maxs, global_encode_scale_multiplier); } else { - output_frgs[v] = cutlass::NumericArrayConverter{}( - cutlass::multiplies>{}( - compute_frgs[v], acc_scale)); + // Accurate math: perform division + pvscales = + cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}( + pvscales, global_encode_scale); + } + auto pvscales_cvted = + cutlass::NumericArrayConverter{}(pvscales); + + tC_rRowSFD_frg(_0{}) = pvscales_cvted; + auto qpvscale_ups = 
cutlass::NumericArrayConverter{}( + tC_rRowSFD_frg(_0{})); + auto qpvscale_scaled = cutlass::multiplies>{}( + qpvscale_ups, global_decode_scale); + cutlass::Array acc_scales; + if constexpr (kUseFastMath) { + // Fast math: compute approximate reciprocal + acc_scales = + cutlass::reciprocal_approximate_ftz{}(qpvscale_scaled); + } else { + // Accurate math: compute reciprocal with division + acc_scales = cutlass::divides>{}( + 1.0, qpvscale_scaled); } - } - copy(tiled_r2g, src, dst); + // Initialize RNG for tile + const size_t rng_sequence = + thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256; + + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = uint4{0, 0, 0, 0}; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + auto acc_scale = cutlass::minimum_with_nan_propagation{}( + acc_scales[v], cutlass::platform::numeric_limits::max()); + // auto acc_scale = acc_scales[v]; + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter( + cutlass::multiplies>{}( + compute_frgs[v], acc_scale), + reinterpret_cast *>(&random_uint4)); + } else { + output_frgs[v] = cutlass::NumericArrayConverter{}( + cutlass::multiplies>{}( + compute_frgs[v], acc_scale)); + } + } - // copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrC, tDgC); + copy(tiled_r2g, src, dst); - copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFC, tDgSFC); - } - linear_tile_idx += gridDim.x; - tile_idx_m = linear_tile_idx % tiles_in_m; - tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; - } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); - } + // copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrC, tDgC); + + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFC, tDgSFC); + } + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / 
tiles_in_m) * K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + } } } diff --git a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu index ed0b4b089f..77b483fdba 100644 --- a/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu @@ -1379,9 +1379,9 @@ void hadamard_transform_cast_fusion(const Tensor &input_, Tensor &output_, } // namespace transformer_engine void nvte_quantize_with_hadamard_transform(const NVTETensor input, NVTETensor output, - const NVTETensor hadamard_matrix, - const NVTEQuantizationConfig quant_config, - cudaStream_t stream) { + const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { NVTE_API_CALL(nvte_quantize_with_hadamard_transform); using namespace transformer_engine; QuantizationConfig quant_config_cpp; diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h index cf7a1640b9..47452b34c4 100644 --- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -72,9 +72,9 @@ void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTE * \param[in] stream CUDA stream used for the operation. */ void nvte_quantize_with_hadamard_transform(const NVTETensor input, NVTETensor output, - const NVTETensor hadamard_matrix, - const NVTEQuantizationConfig quant_config, - cudaStream_t stream); + const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream); /*! \brief Split a tensor along dimension 0 and compute RHT amaxes for each split. 
* diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 1a8069df09..3426160272 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -2265,7 +2265,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // 2. RHT followed by columnwise quantization & transpose NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_with_hadamard_transform(input.data(), out.data(), rht_matrix_nvte.data(), - quant_config, stream); + quant_config, stream); }); } else { // Use separate RNG state for columnwise to ensure different random numbers than rowwise From 2f6b001ef5047b02df1c0942ede4ea7350c7c9b5 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Tue, 3 Mar 2026 18:09:10 -0800 Subject: [PATCH 20/22] try to fix compile CI again Signed-off-by: Zhongbo Zhu --- transformer_engine/common/util/ptx.cuh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 5367d7e781..fc9e92cc2e 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -14,6 +14,10 @@ #include #include +#ifndef FP4_TYPE_SUPPORTED +#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080) +#endif + #if CUDA_VERSION >= 12080 #include #endif // CUDA_VERSION >= 12080 From 67d4684732e8a07a075b41d5b7c02acad7909363 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Tue, 3 Mar 2026 18:16:15 -0800 Subject: [PATCH 21/22] AI code review comments Signed-off-by: Zhongbo Zhu --- benchmarks/linear/benchmark_linear.py | 6 ++++-- transformer_engine/pytorch/csrc/extensions/cast.cpp | 2 +- transformer_engine/pytorch/csrc/quantizer.cpp | 4 +++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/benchmarks/linear/benchmark_linear.py b/benchmarks/linear/benchmark_linear.py index b293c44fc9..4230db446d 100644 --- a/benchmarks/linear/benchmark_linear.py +++ 
b/benchmarks/linear/benchmark_linear.py @@ -282,6 +282,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, fwd_only=False): else: recipe_list = [args.recipe] + profiler_ctx = None if args.profile: hidden_dim_to_profile = 4096 if args.hidden_dim is None else args.hidden_dim output_dim_to_profile = 4096 if args.output_dim is None else args.output_dim @@ -293,7 +294,8 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, fwd_only=False): " fp8_sub_channel, mxfp8, nvfp4, or bf16" ) recipe_list = [args.recipe] - torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__() + profiler_ctx = torch.autograd.profiler.emit_nvtx(record_shapes=True) + profiler_ctx.__enter__() # Initialize a dataframe to store the results df_linears = pd.DataFrame() @@ -327,4 +329,4 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, fwd_only=False): print(df_linears) if args.profile: - torch.autograd.profiler.emit_nvtx().__exit__(None, None, None) + profiler_ctx.__exit__(None, None, None) diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 4459fdcc65..f4c7386297 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -999,7 +999,7 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, // and inconsistently implemented. // What math is accelerated? Only the high precision math, so numerical impact is minimal // 1. replace x / y by x * (1/y) - // 2. replace 1 / x by reciporal_approximate_ftz(x) + // 2. replace 1 / x by reciprocal_approximate_ftz(x) // 3. 
when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, // this will essentially remove a round trip between FP32 to BF16 then FP32 const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 3426160272..6aad0331ae 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -2247,7 +2247,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // Fast math toggle: RHT transform can be accelerated // What math is accelerated? Only the high precision math, so numerical impact is minimal // 1. replace x / y by x * (1/y) - // 2. replace 1 / x by reciporal_approximate_ftz(x) + // 2. replace 1 / x by reciprocal_approximate_ftz(x) // 3. when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, // this will essentially remove a round trip between FP32 to BF16 then FP32 const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); @@ -2259,6 +2259,8 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou if (this->with_rht) { if (eligible_for_rht_cast_fusion) { // fusion kernel requires passing in RHT matrix directly for maximum performance + NVTE_CHECK(this->rht_matrix.defined() && this->rht_matrix.numel() > 0, + "RHT matrix is not available."); auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix); // Fusion kernel that does the following: // 1. 
Rowwise quantization From 999fe85c923ce5fda5cdbb4d4933bb951ef44186 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu Date: Tue, 3 Mar 2026 21:11:23 -0800 Subject: [PATCH 22/22] to pass oldest compile CI with cuda 12.1 Signed-off-by: Zhongbo Zhu --- transformer_engine/common/util/ptx.cuh | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index fc9e92cc2e..f7611e60c5 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -14,13 +14,11 @@ #include #include -#ifndef FP4_TYPE_SUPPORTED -#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080) -#endif +#include "common/common.h" -#if CUDA_VERSION >= 12080 +#if FP4_TYPE_SUPPORTED #include -#endif // CUDA_VERSION >= 12080 +#endif // FP4_TYPE_SUPPORTED #include "common/utils.cuh"