From 26ff5be0314381464d6a6eb3387d8994dd07a8b6 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Tue, 3 Mar 2026 19:27:40 -0800 Subject: [PATCH 1/6] Add unswizzling functions for scaling factors in swizzle module - Introduced `nvte_unswizzle_scaling_factors` to convert swizzled scaling factors back to row-major format. - Implemented `regs_unshuffle_with_bit_shifts` and `regs_unshuffle` for unshuffling operations in CUDA kernels. - Added `unswizzle_row_scaling_kernel_impl` and `unswizzle_col_scaling_kernel_impl` for handling unswizzling in row and column scaling respectively. These changes enhance the functionality of the swizzle module, enabling better handling of scaling factors in tensor operations. Signed-off-by: Abhishek --- .../include/transformer_engine/swizzle.h | 12 + transformer_engine/common/swizzle/swizzle.cu | 431 ++++++++++++++++++ 2 files changed, 443 insertions(+) diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h index 5e420b2d42..692d5f8e77 100644 --- a/transformer_engine/common/include/transformer_engine/swizzle.h +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -45,6 +45,18 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, const size_t num_tensors, cudaStream_t stream); +/*! \brief Unswizzling scaling factors from the interleaved layout used by GEMM back to row-major + * + * \param[in] input Input tensor with swizzled scale_inv. + * \param[in,out] output Output tensor which hosts non-swizzled scale_inv. + * \param[in] stream CUDA stream used for the operation. + * + * Requirements: + * - scale_inv is stored in row-major in output. + * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale. + * - data is quantitized along K-dimension, i.e. 1D-scaling block lies along the K-dimension. + */ +void nvte_unswizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Swizzling FP8 block scaling scaling factors into mxfp8 interleaved layout for GEMM * * \param[in] input Input FP8 block-scaled tensor. diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 4425c4e9f7..2a477aa810 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -54,6 +54,31 @@ __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) { for (int i = 0; i < kVectorSize; i++) regs[i] = new_regs[i]; } +template +__device__ inline void regs_unshuffle_with_bit_shifts(LType* regs_vec) { + // Inverse of regs_shuffle_with_bit_shifts + // inp, 4-byte chunks [0,4,8,12, 1,5,9,13, 2,6,10,14, 3,7,11,15] + // out, swapping byte to form new 4-byte chunks [0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15] + + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int kVectorSize = N_SF_PER_TD_PER_TILE * N_TILE_PER_TD; + int32_t new_regs[kVectorSize]; + int32_t* regs = reinterpret_cast(regs_vec); + +#pragma unroll + for (int i = 0; i < N_TILE_PER_TD; i++) { +#pragma unroll + for (int j = 0; j < N_SF_PER_TD_PER_TILE; j++) { + new_regs[i + j * N_TILE_PER_TD] = + ((regs[i * N_SF_PER_TD_PER_TILE + 0] >> 8 * j) & 0xFF) | + (((regs[i * N_SF_PER_TD_PER_TILE + 1] >> 8 * j) & 0xFF) << 8) | + (((regs[i * N_SF_PER_TD_PER_TILE + 2] >> 8 * j) & 0xFF) << 16) | + (((regs[i * N_SF_PER_TD_PER_TILE + 3] >> 8 * j) & 0xFF) << 24); + } + } +#pragma unroll + for (int i = 0; i < kVectorSize; i++) regs[i] = new_regs[i]; +} template __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, const int M, const int K, const int original_M, @@ -170,6 +195,23 @@ __device__ inline void regs_shuffle(LType* regs_vec) { for (int i = 0; i < kVectorSize; i++) ptr[i] = tmp[i]; } +// Inverse of regs_shuffle. +template +__device__ inline void regs_unshuffle(LType* regs_vec) { + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + if constexpr (N_TILE_PER_TD == 1) return; + + constexpr int kVectorSize = N_SF_PER_TD_PER_TILE * N_TILE_PER_TD; + int32_t tmp[kVectorSize]; + int32_t* ptr = reinterpret_cast(regs_vec); +#pragma unroll + for (int i = 0; i < kVectorSize; i++) + tmp[i % N_SF_PER_TD_PER_TILE * N_TILE_PER_TD + i / N_SF_PER_TD_PER_TILE] = ptr[i]; + +#pragma unroll + for (int i = 0; i < kVectorSize; i++) ptr[i] = tmp[i]; +} + template __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, const int M, const int K, const int original_M, @@ -239,6 +281,142 @@ __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, } } +template +__device__ void unswizzle_row_scaling_kernel_impl(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K, const int bid_x, + const int bid_y, const int grid_dim_x, + const int grid_dim_y) { + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD; + + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M; + + int n_tiles_in_tb = N_TILES_IN_TB; + const int K_i32 = K / 4; + if (bid_x == grid_dim_x - 1) { + n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1; + } + + const int input_offset = bid_y * SF_TILE_DIM_M_I32 * K_i32 + + bid_x * N_TILES_IN_TB * SF_TILE_SIZE_I32; + const int* input_i32 = reinterpret_cast(input) + input_offset; + const int output_offset = bid_y * SF_TILE_DIM_M_I32 * K_i32 + bid_x * N_TILES_IN_TB; + uint8_t* output_u8 = reinterpret_cast(output); + + extern __shared__ int4 slm_v4i[]; + + int linear_id = threadIdx.y * blockDim.x + threadIdx.x; + const int4* input_v4i = reinterpret_cast(input_i32); +#pragma unroll + for (int i = linear_id; i < SF_TILE_SIZE_I32 * n_tiles_in_tb / 4; i += blockDim.x * blockDim.y) { + slm_v4i[i] = input_v4i[i]; + } + __syncthreads(); + + LType regs_vec[N_SF_PER_TD_PER_TILE]; + if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) { +#pragma unroll + for (int i = 0; i < N_TILE_PER_TD; i++) { + reinterpret_cast(regs_vec)[i] = + slm_v4i[(threadIdx.x * N_TILE_PER_TD + i) * SF_TILE_SIZE_I32 / 4 + threadIdx.y]; + } + + regs_unshuffle(regs_vec); + +#pragma unroll + for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + const int thread_offset = (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD; + for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { + const int index = (output_offset + thread_offset) * sizeof(int) + j; + if (index / K < original_M && index % K < original_K) { + output_u8[index / K * original_K + index % K] = reinterpret_cast(regs_vec + i)[j]; + } + } + } + } +} + +template +__device__ void unswizzle_col_scaling_kernel_impl(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K, const int bid_x, + const int bid_y, const int grid_dim_x, + const int grid_dim_y) { + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE; + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + + constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M / 4; + constexpr int SF_TILE_DIM_K_I32 = SF_TILE_DIM_K; + + const int M_i32 = M / 4; + const int K_i32 = K; + + int m_tiles_in_tb = N_TILE_PER_TD; + int k_tiles_in_tb = TB_DIM; + if (bid_x == grid_dim_x - 1) { + k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1; + } + if (bid_y == grid_dim_y - 1) { + m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1; + } + + const int32_t* input_i32[N_TILE_PER_TD]; +#pragma unroll + for (int i = 0; i < m_tiles_in_tb; i++) { + input_i32[i] = reinterpret_cast(input) + bid_x * TB_DIM * SF_TILE_SIZE_I32 + + (bid_y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32; + } + const int output_offset = + bid_x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + bid_y * N_TILE_PER_TD * SF_TILE_DIM_M_I32; + uint8_t* output_u8 = reinterpret_cast(output); + + extern __shared__ int slm[]; + + int linear_id = threadIdx.y * blockDim.x + threadIdx.x; +#pragma unroll + for (int i = 0; i < m_tiles_in_tb; i++) { + __align__(16) int4* input_v4i = reinterpret_cast(const_cast(input_i32[i])); + __align__(16) int4* slm_v4i = + reinterpret_cast(slm + i * k_tiles_in_tb * SF_TILE_SIZE_I32); +#pragma unroll + for (int j = linear_id; j < SF_TILE_SIZE_I32 * k_tiles_in_tb / 4; + j += blockDim.x * blockDim.y) { + slm_v4i[j] = input_v4i[j]; + } + } + __syncthreads(); + + LType regs_vec[N_SF_PER_TD_PER_TILE]; + if (threadIdx.x * N_TILE_PER_TD < m_tiles_in_tb * SF_TILE_DIM_M_I32 && + threadIdx.y < k_tiles_in_tb) { + int tM = threadIdx.x * N_SF_PER_TD; + int* slm_tile = slm + (threadIdx.y * SF_TILE_SIZE_I32 + + tM / SF_TILE_DIM_M * k_tiles_in_tb * SF_TILE_SIZE_I32); +#pragma unroll + for (int i = 0; i < N_SF_PER_TD; i++) { + reinterpret_cast(regs_vec)[i] = + slm_tile[(tM % SF_TILE_DIM_M) / NEW_SF_TILE_DIM_M_I32 + + ((tM + i) % NEW_SF_TILE_DIM_M_I32) * NEW_SF_TILE_DIM_K_I32]; + } + + regs_unshuffle_with_bit_shifts(regs_vec); + +#pragma unroll + for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + const int thread_offset = + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD; + for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { + const int index = (output_offset + thread_offset) * sizeof(int) + j; + if (index / M < original_K && index % M < original_M) { + output_u8[index / M * original_M + index % M] = reinterpret_cast(regs_vec + i)[j]; + } + } + } + } +} template __global__ void __launch_bounds__(TB_DIM* TB_DIM) swizzle_row_scaling_kernel(const void* input, void* output, const int M, const int K, @@ -247,6 +425,23 @@ __global__ void __launch_bounds__(TB_DIM* TB_DIM) input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); } +template +__global__ void __launch_bounds__(TB_DIM* TB_DIM) + unswizzle_scaling_kernel(const void* input, void* output, const int M, const int K, + const int original_M, const int original_K, const bool row_scaling) { + const int bid_x = blockIdx.x; + const int bid_y = blockIdx.y; + const int grid_dim_x = gridDim.x; + const int grid_dim_y = gridDim.y; + if (row_scaling) { + unswizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } else { + unswizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } +} + constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB struct MultiSwizzleArgs { // (input) Data buffers for input scaling factors @@ -816,6 +1011,236 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, kernel_args, vec_load_size, false, stream); } } + +void unswizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { + const auto& scaling_mode = input->scaling_mode; + NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING, + "Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ")."); + + CheckInputTensor(*input, "scaling_factor_input"); + CheckInputTensor(*output, "scaling_factor_output"); + NVTE_CHECK(input->with_gemm_swizzled_scales, "Expected input tensor with swizzled scales."); + NVTE_CHECK(!output->with_gemm_swizzled_scales, "Expected output tensor in row-major compact format."); + + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: + NVTE_CHECK(is_fp8_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP8, got ", + to_string(input->dtype()), ")."); + break; + case NVTE_NVFP4_1D_SCALING: + NVTE_CHECK(is_fp4_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP4, got ", + to_string(input->dtype()), ")."); + break; + default: + NVTE_ERROR("Invalid scaling mode"); + } + + const bool has_rowwise_scale_inv = input->scale_inv.has_data(); + const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data(); + NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv, + "Input tensor has both row-wise and column-wise scaling factors"); + if (!has_rowwise_scale_inv && !has_columnwise_scale_inv) { + return; + } + + int m{0}, k{0}; + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: { + if (has_rowwise_scale_inv) { + NVTE_CHECK(input->scale_inv.shape.size() == 2, + "Expected 2D scaling factors, got shape=", input->scale_inv.shape, "."); + m = input->scale_inv.shape[0]; + k = input->scale_inv.shape[1]; + } else if (has_columnwise_scale_inv) { + NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2, + "Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape, + "."); + m = input->columnwise_scale_inv.shape[1]; + k = input->columnwise_scale_inv.shape[0]; + } + break; + } + case NVTE_NVFP4_1D_SCALING: { + if (has_rowwise_scale_inv) { + NVTE_CHECK(input->scale_inv.shape.size() == 2, + "Expected 2D scaling factors, got shape=", input->scale_inv.shape, "."); + m = input->scale_inv.shape[0]; + k = input->scale_inv.shape[1]; + } else if (has_columnwise_scale_inv) { + NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2, + "Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape, + "."); + m = input->columnwise_scale_inv.shape[0]; + k = input->columnwise_scale_inv.shape[1]; + } + break; + } + default: + NVTE_ERROR("Invalid scaling mode"); + } + + constexpr int SF_TILE_DIM_M = 128; + constexpr int SF_TILE_DIM_K = 4; + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + + if (has_rowwise_scale_inv) { + NVTE_CHECK(output->scale_inv.has_data(), + "Output tensor does not have row-wise scaling factors."); + } + if (has_columnwise_scale_inv) { + NVTE_CHECK(output->columnwise_scale_inv.has_data(), + "Output tensor does not have column-wise scaling factors."); + } + + bool rowwise_unswizzle{false}, columnwise_unswizzle{false}; + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: { + rowwise_unswizzle = has_rowwise_scale_inv; + columnwise_unswizzle = has_columnwise_scale_inv; + break; + } + case NVTE_NVFP4_1D_SCALING: { + rowwise_unswizzle = true; + columnwise_unswizzle = false; + break; + } + default: + NVTE_ERROR("Invalid scaling mode"); + } + + const dim3 block_size(TB_DIM, TB_DIM); + const int num_tiles_m = m / SF_TILE_DIM_M; + const int num_tiles_k = k / SF_TILE_DIM_K; + + if (rowwise_unswizzle) { + int vec_load_size = (num_tiles_k - 1) % 4 + 1; + if (vec_load_size == 3) vec_load_size = 1; + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + + int original_M{0}, original_K{0}; + void *input_scale_inv_ptr{nullptr}, *output_scale_inv_ptr{nullptr}; + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: { + original_M = output->flat_first_dim(); + original_K = output->flat_last_dim() / MXFP8_BLOCK_SIZE; + input_scale_inv_ptr = input->scale_inv.dptr; + output_scale_inv_ptr = output->scale_inv.dptr; + NVTE_CHECK(static_cast(original_M) * original_K == output->scale_inv.numel(), + "Expected output tensor to have ", static_cast(original_M) * original_K, + " row-wise scaling factors, but got shape=", output->scale_inv.shape, "."); + break; + } + case NVTE_NVFP4_1D_SCALING: { + if (has_rowwise_scale_inv) { + original_M = output->flat_first_dim(); + original_K = output->flat_last_dim() / NVFP4_BLOCK_SIZE; + input_scale_inv_ptr = input->scale_inv.dptr; + output_scale_inv_ptr = output->scale_inv.dptr; + NVTE_CHECK( + static_cast(original_M) * original_K == output->scale_inv.numel(), + "Expected output tensor to have ", static_cast(original_M) * original_K, + " row-wise scaling factors, but got shape=", output->scale_inv.shape, "."); + } else if (has_columnwise_scale_inv) { + original_M = output->flat_last_dim(); + original_K = output->flat_first_dim() / NVFP4_BLOCK_SIZE; + input_scale_inv_ptr = input->columnwise_scale_inv.dptr; + output_scale_inv_ptr = output->columnwise_scale_inv.dptr; + NVTE_CHECK( + static_cast(original_M) * original_K == output->columnwise_scale_inv.numel(), + "Expected output tensor to have ", static_cast(original_M) * original_K, + " column-wise scaling factors, but got shape=", output->columnwise_scale_inv.shape, + "."); + } + break; + } + default: + NVTE_ERROR("Invalid scaling mode"); + } + + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(unswizzle_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + unswizzle_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K, true); + break; + case 2: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(unswizzle_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + unswizzle_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K, true); + break; + case 1: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(unswizzle_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + unswizzle_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K, true); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + NVTE_CHECK_CUDA(cudaGetLastError()); + } + + if (columnwise_unswizzle) { + int vec_load_size = (num_tiles_m - 1) % 4 + 1; + if (vec_load_size == 3) vec_load_size = 1; + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + const int original_M = output->flat_last_dim(); + const int original_K = output->flat_first_dim() / MXFP8_BLOCK_SIZE; + NVTE_CHECK(static_cast(original_M) * original_K == output->columnwise_scale_inv.numel(), + "Expected output tensor to have ", static_cast(original_M) * original_K, + " column-wise scaling factors, but got shape=", output->columnwise_scale_inv.shape, + "."); + + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(unswizzle_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + unswizzle_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, k, + original_M, original_K, false); + break; + case 2: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(unswizzle_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + unswizzle_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, k, + original_M, original_K, false); + break; + case 1: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(unswizzle_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + unswizzle_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, k, + original_M, original_K, false); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + NVTE_CHECK_CUDA(cudaGetLastError()); + } +} + } // namespace transformer_engine /* @@ -841,3 +1266,9 @@ void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETen } multi_tensor_swizzle_scaling_factors(input_list, output_list, stream); } + +void nvte_unswizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_unswizzle_scaling_factors); + using namespace transformer_engine; + unswizzle_scaling_factors(convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream); +} From 6a064cf4570f562b9f6cd36024bcc2426f45c130 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Tue, 3 Mar 2026 19:32:44 -0800 Subject: [PATCH 2/6] Add swizzle/unswizzle roundtrip test for scaling factors These enhancements tests the changes introduced for unswizzling Signed-off-by: Abhishek --- tests/cpp/operator/test_swizzle.cu | 84 ++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu index 694b348a9b..a6f98228f8 100644 --- a/tests/cpp/operator/test_swizzle.cu +++ b/tests/cpp/operator/test_swizzle.cu @@ -164,3 +164,87 @@ INSTANTIATE_TEST_SUITE_P( std::to_string(std::get<2>(info.param)); return name; }); + + void performTestSwizzleUnswizzleRoundtrip(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) { + using namespace test; + + int SF_MODE_X, SF_MODE_Y; + if (rowwise) { + SF_MODE_X = 1; + SF_MODE_Y = 32; + } + if (columnwise) { + SF_MODE_X = 32; + SF_MODE_Y = 1; + } + + if ((rowwise && columnwise) || !(rowwise || columnwise)){ + GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" + + std::to_string(SF_MODE_Y) + "is not implemented."; + } + + DType dtype = DType::kFloat8E4M3; + + const size_t M = num_tiles_M * MAT_TILE_DIM_M; + const size_t K = num_tiles_K * MAT_TILE_DIM_K; + const auto data_shape = transa ? std::vector{M, K} : std::vector{K, M}; + + const auto scale_shape = std::vector{data_shape[0] / SF_MODE_X, data_shape[1] /SF_MODE_Y}; + + Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + Tensor swizzled("swizzled", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + swizzled.set_with_gemm_swizzled_scales(true); + Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + + fillUniform(&input); + + nvte_swizzle_scaling_factors(input.data(), swizzled.data(), 0); + nvte_unswizzle_scaling_factors(swizzled.data(), output.data(), 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + input.to_cpu(); + output.to_cpu(); + if (rowwise) { + compareResults("roundtrip_rowwise", output.rowwise_cpu_scale_inv_ptr(), + input.rowwise_cpu_scale_inv_ptr(), scale_shape[0] * scale_shape[1]); + } else { + compareResults("roundtrip_columnwise", output.columnwise_cpu_scale_inv_ptr(), + input.columnwise_cpu_scale_inv_ptr(), scale_shape[0] * scale_shape[1]); + } + } + + class SwizzleUnswizzleRoundtripTestSuite : public ::testing::TestWithParam, std::pair, bool>> {}; + + TEST_P(SwizzleUnswizzleRoundtripTestSuite, TestSwizzleUnswizzleRoundtrip) { + using namespace transformer_engine; + using namespace test; + + const auto num_tiles = std::get<0>(GetParam()); + const auto scaling_mode = std::get<1>(GetParam()); + const auto transa = std::get<2>(GetParam()); + + performTestSwizzleUnswizzleRoundtrip(num_tiles.first, num_tiles.second, + scaling_mode.first, scaling_mode.second, + transa); + } + + INSTANTIATE_TEST_SUITE_P( + OperatorTest, + SwizzleUnswizzleRoundtripTestSuite, + ::testing::Combine( + ::testing::ValuesIn(num_tiles), + ::testing::ValuesIn(scaling_mode), + ::testing::ValuesIn(transa) + ), + [](const testing::TestParamInfo& info) { + std::string name = "roundtrip_ntiles" + + std::to_string(std::get<0>(info.param).first) + "X" + + std::to_string(std::get<0>(info.param).second) + "smode" + + std::to_string(std::get<1>(info.param).first) + "X"+ + std::to_string(std::get<1>(info.param).second) + "trans" + + std::to_string(std::get<2>(info.param)); + return name; + }); From 7d1567eea80f1b5180e2abc2d8db2e00bcf3fff2 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Tue, 3 Mar 2026 20:05:04 -0800 Subject: [PATCH 3/6] Added another unswizzling functionality test for scaling factors - Introduced `compute_ref_unswizzle` to handle the conversion of swizzled scaling factors back to their original format. - Added `performTestUnswizzle1D` to validate the unswizzling process with various scaling modes. - Created `UnswizzleTestSuite` for comprehensive testing of unswizzling operations. Signed-off-by: Abhishek --- tests/cpp/operator/test_swizzle.cu | 281 ++++++++++++++++++++--------- 1 file changed, 198 insertions(+), 83 deletions(-) diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu index a6f98228f8..58c544f08b 100644 --- a/tests/cpp/operator/test_swizzle.cu +++ b/tests/cpp/operator/test_swizzle.cu @@ -56,6 +56,35 @@ void compute_ref_swizzle(const uint8_t *h_input, uint8_t *h_output, } } +template +void compute_ref_unswizzle(const uint8_t *h_input, uint8_t *h_output, + const size_t M, const size_t K) { + + constexpr int NEW_SF_TILE_DIM_M = SF_TILE_DIM_M / 4; + constexpr int NEW_SF_TILE_DIM_K = SF_TILE_DIM_K * 4; + constexpr int SF_TILE_SIZE = SF_TILE_DIM_M * SF_TILE_DIM_K; + + for (int m = 0; m < M; m++) { + for (int k = 0; k < K; k++) { + + int tile_id_m = m / SF_TILE_DIM_M; + int tile_id_k = k / SF_TILE_DIM_K; + int m_in_tile = m % SF_TILE_DIM_M; + int k_in_tile = k % SF_TILE_DIM_K; + + int row_in_new_tile = m_in_tile % NEW_SF_TILE_DIM_M; + int col_in_new_tile = m_in_tile / NEW_SF_TILE_DIM_M * SF_TILE_DIM_K + k_in_tile; + + int tile_input_ptr = tile_id_m * SF_TILE_DIM_M * K + tile_id_k * SF_TILE_SIZE; + int in_index = tile_input_ptr + row_in_new_tile * NEW_SF_TILE_DIM_K + col_in_new_tile; + if constexpr(row_scaling) + h_output[k + m * K] = h_input[in_index]; + else + h_output[k * M + m] = h_input[in_index]; + } + } +} + void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) { using namespace test; @@ -110,6 +139,59 @@ void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool row } } +void performTestUnswizzle1D(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) { + using namespace test; + + int SF_MODE_X, SF_MODE_Y; + if (rowwise) { + SF_MODE_X = 1; + SF_MODE_Y = 32; + } + if (columnwise) { + SF_MODE_X = 32; + SF_MODE_Y = 1; + } + + if ((rowwise && columnwise) || !(rowwise || columnwise)){ + GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" + + std::to_string(SF_MODE_Y) + "is not implemented."; + } + + DType dtype = DType::kFloat8E4M3; + + const size_t M = num_tiles_M * MAT_TILE_DIM_M; + const size_t K = num_tiles_K * MAT_TILE_DIM_K; + const auto data_shape = transa ? std::vector{M, K} : std::vector{K, M}; + + const auto scale_shape = std::vector{data_shape[0] / SF_MODE_X, data_shape[1] /SF_MODE_Y}; + + Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + input.set_with_gemm_swizzled_scales(true); + Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + + fillUniform(&input); + + std::unique_ptr ref_output = std::make_unique(scale_shape[0] * scale_shape[1]); + + nvte_unswizzle_scaling_factors(input.data(), output.data(), 0); + + if (rowwise) + compute_ref_unswizzle<128, 4, true>(input.rowwise_cpu_scale_inv_ptr(), ref_output.get(), scale_shape[0], scale_shape[1]); + else + compute_ref_unswizzle<128, 4, false>(input.columnwise_cpu_scale_inv_ptr(), ref_output.get(), scale_shape[1], scale_shape[0]); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + output.to_cpu(); + if (rowwise) { + compareResults("output_unswizzle", output.rowwise_cpu_scale_inv_ptr(), ref_output.get(), scale_shape[0] * scale_shape[1]); + } else { + compareResults("output_unswizzle", output.columnwise_cpu_scale_inv_ptr(), ref_output.get(), scale_shape[0] * scale_shape[1]); + } +} + class SwizzleTestSuite : public ::testing::TestWithParam, std::pair, bool>> {}; @@ -126,6 +208,21 @@ TEST_P(SwizzleTestSuite, TestSwizzle) { transa); } +class UnswizzleTestSuite : public ::testing::TestWithParam, std::pair, bool>> {}; + +TEST_P(UnswizzleTestSuite, TestUnswizzle) { + using namespace transformer_engine; + using namespace test; + + const auto num_tiles = std::get<0>(GetParam()); + const auto scaling_mode = std::get<1>(GetParam()); + const auto transa = std::get<2>(GetParam()); + + performTestUnswizzle1D(num_tiles.first, num_tiles.second, + scaling_mode.first, scaling_mode.second, + transa); +} + namespace { std::vector> num_tiles = { @@ -165,86 +262,104 @@ INSTANTIATE_TEST_SUITE_P( return name; }); - void performTestSwizzleUnswizzleRoundtrip(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) { - using namespace test; - - int SF_MODE_X, SF_MODE_Y; - if (rowwise) { - SF_MODE_X = 1; - SF_MODE_Y = 32; - } - if (columnwise) { - SF_MODE_X = 32; - SF_MODE_Y = 1; - } - - if ((rowwise && columnwise) || !(rowwise || columnwise)){ - GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" + - std::to_string(SF_MODE_Y) + "is not implemented."; - } - - DType dtype = DType::kFloat8E4M3; - - const size_t M = num_tiles_M * MAT_TILE_DIM_M; - const size_t K = num_tiles_K * MAT_TILE_DIM_K; - const auto data_shape = transa ? std::vector{M, K} : std::vector{K, M}; - - const auto scale_shape = std::vector{data_shape[0] / SF_MODE_X, data_shape[1] /SF_MODE_Y}; - - Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); - Tensor swizzled("swizzled", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); - swizzled.set_with_gemm_swizzled_scales(true); - Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); - - fillUniform(&input); - - nvte_swizzle_scaling_factors(input.data(), swizzled.data(), 0); - nvte_unswizzle_scaling_factors(swizzled.data(), output.data(), 0); - - cudaDeviceSynchronize(); - auto err = cudaGetLastError(); - ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - - input.to_cpu(); - output.to_cpu(); - if (rowwise) { - compareResults("roundtrip_rowwise", output.rowwise_cpu_scale_inv_ptr(), - input.rowwise_cpu_scale_inv_ptr(), scale_shape[0] * scale_shape[1]); - } else { - compareResults("roundtrip_columnwise", output.columnwise_cpu_scale_inv_ptr(), - input.columnwise_cpu_scale_inv_ptr(), scale_shape[0] * scale_shape[1]); - } - } - - class SwizzleUnswizzleRoundtripTestSuite : public ::testing::TestWithParam, std::pair, bool>> {}; - - TEST_P(SwizzleUnswizzleRoundtripTestSuite, TestSwizzleUnswizzleRoundtrip) { - using namespace transformer_engine; - using namespace test; - - const auto num_tiles = std::get<0>(GetParam()); - const auto scaling_mode = std::get<1>(GetParam()); - const auto transa = std::get<2>(GetParam()); - - performTestSwizzleUnswizzleRoundtrip(num_tiles.first, num_tiles.second, - scaling_mode.first, scaling_mode.second, - transa); - } - - INSTANTIATE_TEST_SUITE_P( - OperatorTest, - SwizzleUnswizzleRoundtripTestSuite, - ::testing::Combine( - ::testing::ValuesIn(num_tiles), - ::testing::ValuesIn(scaling_mode), - ::testing::ValuesIn(transa) - ), - [](const testing::TestParamInfo& info) { - std::string name = "roundtrip_ntiles" + - std::to_string(std::get<0>(info.param).first) + "X" + - std::to_string(std::get<0>(info.param).second) + "smode" + - std::to_string(std::get<1>(info.param).first) + "X"+ - std::to_string(std::get<1>(info.param).second) + "trans" + - std::to_string(std::get<2>(info.param)); - return name; - }); +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + UnswizzleTestSuite, + ::testing::Combine( + ::testing::ValuesIn(num_tiles), + ::testing::ValuesIn(scaling_mode), + ::testing::ValuesIn(transa) + ), + [](const testing::TestParamInfo& info) { + std::string name = "ntiles" + + std::to_string(std::get<0>(info.param).first) + "X" + + std::to_string(std::get<0>(info.param).second) + "smode" + + std::to_string(std::get<1>(info.param).first) + "X"+ + std::to_string(std::get<1>(info.param).second) + "trans" + + std::to_string(std::get<2>(info.param)); + return name; + }); + +void performTestSwizzleUnswizzleRoundtrip(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) { + using namespace test; + + int SF_MODE_X, SF_MODE_Y; + if (rowwise) { + SF_MODE_X = 1; + SF_MODE_Y = 32; + } + if (columnwise) { + SF_MODE_X = 32; + SF_MODE_Y = 1; + } + + if ((rowwise && columnwise) || !(rowwise || columnwise)){ + GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" + + std::to_string(SF_MODE_Y) + "is not implemented."; + } + + DType dtype = DType::kFloat8E4M3; + + const size_t M = num_tiles_M * MAT_TILE_DIM_M; + const size_t K = num_tiles_K * MAT_TILE_DIM_K; + const auto data_shape = transa ? std::vector{M, K} : std::vector{K, M}; + + const auto scale_shape = std::vector{data_shape[0] / SF_MODE_X, data_shape[1] /SF_MODE_Y}; + + Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + Tensor swizzled("swizzled", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + swizzled.set_with_gemm_swizzled_scales(true); + Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + + fillUniform(&input); + + nvte_swizzle_scaling_factors(input.data(), swizzled.data(), 0); + nvte_unswizzle_scaling_factors(swizzled.data(), output.data(), 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + input.to_cpu(); + output.to_cpu(); + if (rowwise) { + compareResults("roundtrip_rowwise", output.rowwise_cpu_scale_inv_ptr(), + input.rowwise_cpu_scale_inv_ptr(), scale_shape[0] * scale_shape[1]); + } else { + compareResults("roundtrip_columnwise", output.columnwise_cpu_scale_inv_ptr(), + input.columnwise_cpu_scale_inv_ptr(), scale_shape[0] * scale_shape[1]); + } +} + +class SwizzleUnswizzleRoundtripTestSuite : public ::testing::TestWithParam, std::pair, bool>> {}; + +TEST_P(SwizzleUnswizzleRoundtripTestSuite, TestSwizzleUnswizzleRoundtrip) { + using namespace transformer_engine; + using namespace test; + + const auto num_tiles = std::get<0>(GetParam()); + const auto scaling_mode = std::get<1>(GetParam()); + const auto transa = std::get<2>(GetParam()); + + performTestSwizzleUnswizzleRoundtrip(num_tiles.first, num_tiles.second, + scaling_mode.first, scaling_mode.second, + transa); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + SwizzleUnswizzleRoundtripTestSuite, + ::testing::Combine( + ::testing::ValuesIn(num_tiles), + ::testing::ValuesIn(scaling_mode), + ::testing::ValuesIn(transa) + ), + [](const testing::TestParamInfo& info) { + std::string name = "roundtrip_ntiles" + + std::to_string(std::get<0>(info.param).first) + "X" + + std::to_string(std::get<0>(info.param).second) + "smode" + + std::to_string(std::get<1>(info.param).first) + "X"+ + std::to_string(std::get<1>(info.param).second) + "trans" + + std::to_string(std::get<2>(info.param)); + return name; + }); From fd3ff05dd2260be938d406a205492b5905474f5e Mon Sep 17 00:00:00 2001 From: Abhishek Date: Tue, 3 Mar 2026 20:32:31 -0800 Subject: [PATCH 4/6] Moved swizzle_row_scaling_kernel implementation at its original place - Moved the definition of `swizzle_row_scaling_kernel` to a new location for better organization. - Ensured the kernel implementation is now properly defined and accessible for scaling operations in the swizzle module. Signed-off-by: Abhishek --- transformer_engine/common/swizzle/swizzle.cu | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 2a477aa810..b887f0c0bf 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -417,13 +417,6 @@ __device__ void unswizzle_col_scaling_kernel_impl(const void* input, void* outpu } } } -template -__global__ void __launch_bounds__(TB_DIM* TB_DIM) - swizzle_row_scaling_kernel(const void* input, void* output, const int M, const int K, - const int original_M, const int original_K) { - swizzle_row_scaling_kernel_impl( - input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); -} template __global__ void __launch_bounds__(TB_DIM* TB_DIM) @@ -442,6 +435,14 @@ __global__ void __launch_bounds__(TB_DIM* TB_DIM) } } +template +__global__ void __launch_bounds__(TB_DIM* TB_DIM) + swizzle_row_scaling_kernel(const void* input, void* output, const int M, const int K, + const int original_M, const int original_K) { + swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); +} + constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB struct MultiSwizzleArgs { // (input) Data buffers for input scaling factors From 64b86d79fe2bb4089d68903b8ab50ba756ad73fa Mon Sep 17 00:00:00 2001 From: Abhishek Date: Tue, 3 Mar 2026 20:37:06 -0800 Subject: [PATCH 5/6] Add multi-tensor unswizzling functions for scaling factors - Introduced `multi_tensor_unswizzle_scaling_factors` to convert swizzled scaling factors back to their original row-major format. - Implemented CUDA kernels for unswizzling in both row and column scaling, enhancing the swizzle module's functionality. - Updated the launch function to handle multiple tensor unswizzling operations efficiently. These changes improve the handling of scaling factors in tensor operations, ensuring better performance and organization within the swizzle module. Signed-off-by: Abhishek --- .../include/transformer_engine/swizzle.h | 16 + transformer_engine/common/swizzle/swizzle.cu | 310 ++++++++++++++++++ 2 files changed, 326 insertions(+) diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h index 692d5f8e77..a26d378437 100644 --- a/transformer_engine/common/include/transformer_engine/swizzle.h +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -57,6 +57,22 @@ void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETen * - data is quantitized along K-dimension, i.e. 1D-scaling block lies along the K-dimension. */ void nvte_unswizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream); + +/*! \brief Unswizzling scaling factors from the interleaved layout used by GEMM back to row-major + * + * \param[in] inputs Input tensors with swizzled scale_inv. + * \param[in,out] outputs Output tensors which hosts non-swizzled scale_inv. + * \param[in] num_tensors Number of input and output tensors. + * \param[in] stream CUDA stream used for the operation. + * + * Requirements: + * - scale_inv is stored in row-major in output. + * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale. + * - data is quantitized along K-dimension, i.e. 1D-scaling block lies along the K-dimension. + */ +void nvte_multi_tensor_unswizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, + const size_t num_tensors, cudaStream_t stream); + /*! \brief Swizzling FP8 block scaling scaling factors into mxfp8 interleaved layout for GEMM * * \param[in] input Input FP8 block-scaled tensor. diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index b887f0c0bf..92beecd95c 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -464,6 +464,63 @@ struct MultiSwizzleArgs { int num_tensors; }; +template +__global__ void multi_tensor_unswizzle_row_scaling_kernel(MultiSwizzleArgs kernel_args) { + const int bid = blockIdx.x; + int tensor_id = 0; + while (kernel_args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + const void* input = kernel_args.input_list[tensor_id]; + void* output = kernel_args.output_list[tensor_id]; + const int M = kernel_args.m_list[tensor_id]; + const int K = kernel_args.k_list[tensor_id]; + const int original_M = kernel_args.original_m_list[tensor_id]; + const int original_K = kernel_args.original_k_list[tensor_id]; + + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD; + + const int num_tiles_k = K / SF_TILE_DIM_K; + const int num_tiles_m = M / SF_TILE_DIM_M; + const int flat_offset = bid - kernel_args.block_range[tensor_id]; + const int grid_dim_x = DIVUP(num_tiles_k, N_TILES_IN_TB); + const int grid_dim_y = num_tiles_m; + const int bid_x = flat_offset / grid_dim_y; + const int bid_y = flat_offset % grid_dim_y; + + unswizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); +} + +template +__global__ void multi_tensor_unswizzle_col_scaling_kernel(MultiSwizzleArgs kernel_args) { + const int bid = blockIdx.x; + int tensor_id = 0; + while (kernel_args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + const void* input = kernel_args.input_list[tensor_id]; + void* output = kernel_args.output_list[tensor_id]; + const int M = kernel_args.m_list[tensor_id]; + const int K = kernel_args.k_list[tensor_id]; + const int original_M = kernel_args.original_m_list[tensor_id]; + const int original_K = kernel_args.original_k_list[tensor_id]; + + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + + const int num_tiles_k = K / SF_TILE_DIM_K; + const int num_tiles_m = M / SF_TILE_DIM_M; + const int flat_offset = bid - kernel_args.block_range[tensor_id]; + const int grid_dim_x = DIVUP(num_tiles_k, TB_DIM); + const int grid_dim_y = DIVUP(num_tiles_m, N_TILE_PER_TD); + const int bid_x = flat_offset / grid_dim_y; + const int bid_y = flat_offset % grid_dim_y; + + unswizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); +} + template __global__ void multi_tensor_swizzle_row_scaling_kernel(MultiSwizzleArgs kernel_args) { // Find tensor corresponding to block @@ -843,6 +900,89 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, NVTE_CHECK_CUDA(cudaGetLastError()); } +template +void launch_multi_tensor_unswizzle_scaling_factors(MultiSwizzleArgs& kernel_args, + const int vec_load_size, const bool is_rowwise, + cudaStream_t stream) { + int n_tiles_in_tb = TB_DIM * vec_load_size; + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + for (size_t j = 0; j < kernel_args.num_tensors; j++) { + const int m = kernel_args.m_list[j]; + const int k = kernel_args.k_list[j]; + int num_tiles_m = m / SF_TILE_DIM_M; + int num_tiles_k = k / SF_TILE_DIM_K; + if (is_rowwise) { + kernel_args.block_range[j + 1] = + kernel_args.block_range[j] + DIVUP(num_tiles_k, n_tiles_in_tb) * num_tiles_m; + } else { + kernel_args.block_range[j + 1] = + kernel_args.block_range[j] + + DIVUP(num_tiles_k, TB_DIM) * DIVUP(num_tiles_m, vec_load_size); + } + } + + int num_blocks = kernel_args.block_range[kernel_args.num_tensors]; + if (num_blocks > 0) { + dim3 block_size(TB_DIM, TB_DIM); + if (is_rowwise) { + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + multi_tensor_unswizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + multi_tensor_unswizzle_row_scaling_kernel + <<>>(kernel_args); + break; + case 2: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + multi_tensor_unswizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + multi_tensor_unswizzle_row_scaling_kernel + <<>>(kernel_args); + break; + case 1: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + multi_tensor_unswizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + multi_tensor_unswizzle_row_scaling_kernel + <<>>(kernel_args); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } else { + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + multi_tensor_unswizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + multi_tensor_unswizzle_col_scaling_kernel + <<>>(kernel_args); + break; + case 2: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + multi_tensor_unswizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + multi_tensor_unswizzle_col_scaling_kernel + <<>>(kernel_args); + break; + case 1: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + multi_tensor_unswizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + multi_tensor_unswizzle_col_scaling_kernel + <<>>(kernel_args); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } + NVTE_CHECK_CUDA(cudaGetLastError()); + } +} + void multi_tensor_swizzle_scaling_factors(const std::vector& input, std::vector& output, cudaStream_t stream) { auto num_tensors = input.size(); @@ -1242,6 +1382,163 @@ void unswizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t } } +void multi_tensor_unswizzle_scaling_factors(const std::vector& input, + const std::vector& output, + cudaStream_t stream) { + size_t num_tensors = input.size(); + const auto& first_scaling_mode = input[0]->scaling_mode; + + constexpr int SF_TILE_DIM_M = 128; + constexpr int SF_TILE_DIM_K = 4; + + bool all_has_data = true; + bool all_has_columnwise_data = true; + bool all_nvfp4 = true; + for (size_t i = 0; i < num_tensors; i++) { + const auto scaling_mode = input[i]->scaling_mode; + const auto is_fp8 = is_fp8_dtype(input[i]->dtype()); + const auto is_fp4 = is_fp4_dtype(input[i]->dtype()); + + NVTE_CHECK(scaling_mode == first_scaling_mode, + "All tensors should have the same scaling mode in multi-tensor unswizzle."); + NVTE_CHECK( + (is_fp8 && is_mxfp8_scaling(scaling_mode)) || (is_fp4 && is_nvfp4_scaling(scaling_mode)), + "Not implemented scaling mode " + to_string(scaling_mode) + "."); + NVTE_CHECK(input[i]->with_gemm_swizzled_scales, + "Expected input tensors with scales in GEMM swizzled format."); + NVTE_CHECK(!output[i]->with_gemm_swizzled_scales, + "Expected output tensors with scales in compact format."); + NVTE_CHECK(input[i]->numel() != 0, "Tensor input[", i, "] is empty."); + CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]"); + CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]"); + + all_has_data = all_has_data && input[i]->scale_inv.has_data(); + all_has_columnwise_data = + (all_has_columnwise_data && input[i]->columnwise_scale_inv.has_data()); + all_nvfp4 = all_nvfp4 && is_nvfp4_scaling(scaling_mode); + } + NVTE_CHECK(all_has_data || all_has_columnwise_data, + "All tensors should have data or columnwise data."); + NVTE_CHECK(!all_has_data || !all_has_columnwise_data, + "All tensors have both data and columnwise data."); + + const bool rowwise_swizzle = all_has_data || all_nvfp4; + const bool columnwise_swizzle = all_has_columnwise_data && !all_nvfp4; + + if (rowwise_swizzle) { + MultiSwizzleArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.block_range[0] = 0; + int vec_load_size = 4; + for (size_t i = 0; i < num_tensors; i++) { + if (kernel_args.num_tensors == kMaxTensorsPerKernel) { + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_unswizzle_scaling_factors( + kernel_args, vec_load_size, true, stream); + kernel_args.num_tensors = 0; + vec_load_size = 4; + } + int m, k; + if (all_has_data) { + m = input[i]->scale_inv.shape[0]; + k = input[i]->scale_inv.shape[1]; + } else { + NVTE_CHECK(all_nvfp4, + "When doing rowwise unswizzle with columnwise data, it has to be NVFP4"); + m = input[i]->columnwise_scale_inv.shape[0]; + k = input[i]->columnwise_scale_inv.shape[1]; + } + + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + + if (all_has_data) { + NVTE_CHECK(output[i]->scale_inv.has_data(), "Output tensor ", i, + " does not have row-wise scaling factors."); + NVTE_CHECK(m * k == output[i]->scale_inv.numel(), "Expected output tensor ", i, " to have ", + m * k, " row-wise scaling factors, but got shape=", output[i]->scale_inv.shape, + "."); + } + if (all_has_columnwise_data) { + NVTE_CHECK(output[i]->columnwise_scale_inv.has_data(), "Output tensor ", i, + " does not have column-wise scaling factors."); + NVTE_CHECK(m * k == output[i]->columnwise_scale_inv.numel(), "Expected output tensor ", i, + " to have ", m * k, " column-wise scaling factors, but got shape=", + output[i]->columnwise_scale_inv.shape, "."); + } + + int num_tiles_k = k / SF_TILE_DIM_K; + int vec_load_size_i = (num_tiles_k - 1) % 4 + 1; + vec_load_size = all_nvfp4 ? 1 : std::min(vec_load_size, vec_load_size_i); + + const int pos = kernel_args.num_tensors; + kernel_args.m_list[pos] = m; + kernel_args.k_list[pos] = k; + if (!all_nvfp4 || all_has_data) { + int block_scale_size = all_nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE; + kernel_args.input_list[pos] = const_cast(input[i]->scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->scale_inv.dptr; + kernel_args.original_m_list[pos] = output[i]->flat_first_dim(); + kernel_args.original_k_list[pos] = output[i]->flat_last_dim() / block_scale_size; + } else { + kernel_args.input_list[pos] = const_cast(input[i]->columnwise_scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr; + kernel_args.original_m_list[pos] = output[i]->flat_last_dim(); + kernel_args.original_k_list[pos] = output[i]->flat_first_dim() / NVFP4_BLOCK_SIZE; + } + kernel_args.num_tensors++; + } + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_unswizzle_scaling_factors( + kernel_args, vec_load_size, true, stream); + } + + if (columnwise_swizzle) { + NVTE_CHECK(!all_nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle"); + + MultiSwizzleArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.block_range[0] = 0; + int vec_load_size = 4; + for (size_t i = 0; i < num_tensors; i++) { + if (kernel_args.num_tensors == kMaxTensorsPerKernel) { + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_unswizzle_scaling_factors( + kernel_args, vec_load_size, false, stream); + kernel_args.num_tensors = 0; + vec_load_size = 4; + } + const int m = input[i]->columnwise_scale_inv.shape[1]; + const int k = input[i]->columnwise_scale_inv.shape[0]; + + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(), + output[i]->columnwise_scale_inv.shape.end(), 1, + std::multiplies()), + "Input.columnwise_scale_inv size is not equal to " + "Output.columnwise_scale_inv size!"); + + int num_tiles_k = k / SF_TILE_DIM_K; + int vec_load_size_i = (num_tiles_k - 1) % 4 + 1; + vec_load_size = std::min(vec_load_size, vec_load_size_i); + + const int pos = kernel_args.num_tensors; + kernel_args.input_list[pos] = const_cast(input[i]->columnwise_scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr; + kernel_args.m_list[pos] = m; + kernel_args.k_list[pos] = k; + kernel_args.original_m_list[pos] = output[i]->flat_last_dim(); + kernel_args.original_k_list[pos] = output[i]->flat_first_dim() / MXFP8_BLOCK_SIZE; + kernel_args.num_tensors++; + } + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_unswizzle_scaling_factors( + kernel_args, vec_load_size, false, stream); + } +} } // namespace transformer_engine /* @@ -1273,3 +1570,16 @@ void nvte_unswizzle_scaling_factors(const NVTETensor input, NVTETensor output, c using namespace transformer_engine; unswizzle_scaling_factors(convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream); } + +void nvte_multi_tensor_unswizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, + const size_t num_tensors, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_unswizzle_scaling_factors); + using namespace transformer_engine; + NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0."); + std::vector input_list, output_list; + for (size_t i = 0; i < num_tensors; i++) { + input_list.push_back(convertNVTETensorCheck(inputs[i])); + output_list.push_back(convertNVTETensorCheck(outputs[i])); + } + multi_tensor_unswizzle_scaling_factors(input_list, output_list, stream); +} From 621bc1669208cf7d6ee84a5af04b80740b359bbb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 05:09:59 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/swizzle/swizzle.cu | 28 +++++++++++--------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 92beecd95c..716d40e816 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -299,8 +299,8 @@ __device__ void unswizzle_row_scaling_kernel_impl(const void* input, void* outpu n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1; } - const int input_offset = bid_y * SF_TILE_DIM_M_I32 * K_i32 + - bid_x * N_TILES_IN_TB * SF_TILE_SIZE_I32; + const int input_offset = + bid_y * SF_TILE_DIM_M_I32 * K_i32 + bid_x * N_TILES_IN_TB * SF_TILE_SIZE_I32; const int* input_i32 = reinterpret_cast(input) + input_offset; const int output_offset = bid_y * SF_TILE_DIM_M_I32 * K_i32 + bid_x * N_TILES_IN_TB; uint8_t* output_u8 = reinterpret_cast(output); @@ -331,7 +331,8 @@ __device__ void unswizzle_row_scaling_kernel_impl(const void* input, void* outpu for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { const int index = (output_offset + thread_offset) * sizeof(int) + j; if (index / K < original_M && index % K < original_K) { - output_u8[index / K * original_K + index % K] = reinterpret_cast(regs_vec + i)[j]; + output_u8[index / K * original_K + index % K] = + reinterpret_cast(regs_vec + i)[j]; } } } @@ -411,7 +412,8 @@ __device__ void unswizzle_col_scaling_kernel_impl(const void* input, void* outpu for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { const int index = (output_offset + thread_offset) * sizeof(int) + j; if (index / M < original_K && index % M < original_M) { - output_u8[index / M * original_M + index % M] = reinterpret_cast(regs_vec + i)[j]; + output_u8[index / M * original_M + index % M] = + reinterpret_cast(regs_vec + i)[j]; } } } @@ -902,8 +904,8 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, template void launch_multi_tensor_unswizzle_scaling_factors(MultiSwizzleArgs& kernel_args, - const int vec_load_size, const bool is_rowwise, - cudaStream_t stream) { + const int vec_load_size, const bool is_rowwise, + cudaStream_t stream) { int n_tiles_in_tb = TB_DIM * vec_load_size; int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); for (size_t j = 0; j < kernel_args.num_tensors; j++) { @@ -1161,7 +1163,8 @@ void unswizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t CheckInputTensor(*input, "scaling_factor_input"); CheckInputTensor(*output, "scaling_factor_output"); NVTE_CHECK(input->with_gemm_swizzled_scales, "Expected input tensor with swizzled scales."); - NVTE_CHECK(!output->with_gemm_swizzled_scales, "Expected output tensor in row-major compact format."); + NVTE_CHECK(!output->with_gemm_swizzled_scales, + "Expected output tensor in row-major compact format."); switch (scaling_mode) { case NVTE_MXFP8_1D_SCALING: @@ -1280,10 +1283,10 @@ void unswizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t original_K = output->flat_last_dim() / NVFP4_BLOCK_SIZE; input_scale_inv_ptr = input->scale_inv.dptr; output_scale_inv_ptr = output->scale_inv.dptr; - NVTE_CHECK( - static_cast(original_M) * original_K == output->scale_inv.numel(), - "Expected output tensor to have ", static_cast(original_M) * original_K, - " row-wise scaling factors, but got shape=", output->scale_inv.shape, "."); + NVTE_CHECK(static_cast(original_M) * original_K == output->scale_inv.numel(), + "Expected output tensor to have ", + static_cast(original_M) * original_K, + " row-wise scaling factors, but got shape=", output->scale_inv.shape, "."); } else if (has_columnwise_scale_inv) { original_M = output->flat_last_dim(); original_K = output->flat_first_dim() / NVFP4_BLOCK_SIZE; @@ -1565,7 +1568,8 @@ void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETen multi_tensor_swizzle_scaling_factors(input_list, output_list, stream); } -void nvte_unswizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) { +void nvte_unswizzle_scaling_factors(const NVTETensor input, NVTETensor output, + cudaStream_t stream) { NVTE_API_CALL(nvte_unswizzle_scaling_factors); using namespace transformer_engine; unswizzle_scaling_factors(convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);