diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu index 694b348a9b..58c544f08b 100644 --- a/tests/cpp/operator/test_swizzle.cu +++ b/tests/cpp/operator/test_swizzle.cu @@ -56,6 +56,35 @@ void compute_ref_swizzle(const uint8_t *h_input, uint8_t *h_output, } } +template +void compute_ref_unswizzle(const uint8_t *h_input, uint8_t *h_output, + const size_t M, const size_t K) { + + constexpr int NEW_SF_TILE_DIM_M = SF_TILE_DIM_M / 4; + constexpr int NEW_SF_TILE_DIM_K = SF_TILE_DIM_K * 4; + constexpr int SF_TILE_SIZE = SF_TILE_DIM_M * SF_TILE_DIM_K; + + for (int m = 0; m < M; m++) { + for (int k = 0; k < K; k++) { + + int tile_id_m = m / SF_TILE_DIM_M; + int tile_id_k = k / SF_TILE_DIM_K; + int m_in_tile = m % SF_TILE_DIM_M; + int k_in_tile = k % SF_TILE_DIM_K; + + int row_in_new_tile = m_in_tile % NEW_SF_TILE_DIM_M; + int col_in_new_tile = m_in_tile / NEW_SF_TILE_DIM_M * SF_TILE_DIM_K + k_in_tile; + + int tile_input_ptr = tile_id_m * SF_TILE_DIM_M * K + tile_id_k * SF_TILE_SIZE; + int in_index = tile_input_ptr + row_in_new_tile * NEW_SF_TILE_DIM_K + col_in_new_tile; + if constexpr(row_scaling) + h_output[k + m * K] = h_input[in_index]; + else + h_output[k * M + m] = h_input[in_index]; + } + } +} + void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) { using namespace test; @@ -110,6 +139,59 @@ void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool row } } +void performTestUnswizzle1D(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) { + using namespace test; + + int SF_MODE_X, SF_MODE_Y; + if (rowwise) { + SF_MODE_X = 1; + SF_MODE_Y = 32; + } + if (columnwise) { + SF_MODE_X = 32; + SF_MODE_Y = 1; + } + + if ((rowwise && columnwise) || !(rowwise || columnwise)){ + GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" + + std::to_string(SF_MODE_Y) + "is not implemented."; + } + + DType dtype = DType::kFloat8E4M3; + + const size_t M = num_tiles_M * MAT_TILE_DIM_M; + const size_t K = num_tiles_K * MAT_TILE_DIM_K; + const auto data_shape = transa ? 
std::vector{M, K} : std::vector{K, M}; + + const auto scale_shape = std::vector{data_shape[0] / SF_MODE_X, data_shape[1] /SF_MODE_Y}; + + Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + input.set_with_gemm_swizzled_scales(true); + Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + + fillUniform(&input); + + std::unique_ptr ref_output = std::make_unique(scale_shape[0] * scale_shape[1]); + + nvte_unswizzle_scaling_factors(input.data(), output.data(), 0); + + if (rowwise) + compute_ref_unswizzle<128, 4, true>(input.rowwise_cpu_scale_inv_ptr(), ref_output.get(), scale_shape[0], scale_shape[1]); + else + compute_ref_unswizzle<128, 4, false>(input.columnwise_cpu_scale_inv_ptr(), ref_output.get(), scale_shape[1], scale_shape[0]); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + output.to_cpu(); + if (rowwise) { + compareResults("output_unswizzle", output.rowwise_cpu_scale_inv_ptr(), ref_output.get(), scale_shape[0] * scale_shape[1]); + } else { + compareResults("output_unswizzle", output.columnwise_cpu_scale_inv_ptr(), ref_output.get(), scale_shape[0] * scale_shape[1]); + } +} + class SwizzleTestSuite : public ::testing::TestWithParam, std::pair, bool>> {}; @@ -126,6 +208,21 @@ TEST_P(SwizzleTestSuite, TestSwizzle) { transa); } +class UnswizzleTestSuite : public ::testing::TestWithParam, std::pair, bool>> {}; + +TEST_P(UnswizzleTestSuite, TestUnswizzle) { + using namespace transformer_engine; + using namespace test; + + const auto num_tiles = std::get<0>(GetParam()); + const auto scaling_mode = std::get<1>(GetParam()); + const auto transa = std::get<2>(GetParam()); + + performTestUnswizzle1D(num_tiles.first, num_tiles.second, + scaling_mode.first, scaling_mode.second, + transa); +} + namespace { std::vector> num_tiles = { @@ -164,3 +261,105 @@ INSTANTIATE_TEST_SUITE_P( std::to_string(std::get<2>(info.param)); return name; }); + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + UnswizzleTestSuite, + ::testing::Combine( + ::testing::ValuesIn(num_tiles), + ::testing::ValuesIn(scaling_mode), + ::testing::ValuesIn(transa) + ), + [](const testing::TestParamInfo& info) { + std::string name = "ntiles" + + std::to_string(std::get<0>(info.param).first) + "X" + + std::to_string(std::get<0>(info.param).second) + "smode" + + std::to_string(std::get<1>(info.param).first) + "X"+ + std::to_string(std::get<1>(info.param).second) + "trans" + + std::to_string(std::get<2>(info.param)); + return name; + }); + +void performTestSwizzleUnswizzleRoundtrip(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) { + using namespace test; + + int SF_MODE_X, SF_MODE_Y; + if (rowwise) { + SF_MODE_X = 1; + SF_MODE_Y = 32; + } + if (columnwise) { + SF_MODE_X = 32; + SF_MODE_Y = 1; + } + + if ((rowwise && columnwise) || !(rowwise || columnwise)){ + GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" + + std::to_string(SF_MODE_Y) + "is not implemented."; + } + + DType dtype = DType::kFloat8E4M3; + + const size_t M = num_tiles_M * MAT_TILE_DIM_M; + const size_t K = num_tiles_K * MAT_TILE_DIM_K; + const auto data_shape = transa ? 
std::vector{M, K} : std::vector{K, M}; + + const auto scale_shape = std::vector{data_shape[0] / SF_MODE_X, data_shape[1] /SF_MODE_Y}; + + Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + Tensor swizzled("swizzled", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + swizzled.set_with_gemm_swizzled_scales(true); + Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + + fillUniform(&input); + + nvte_swizzle_scaling_factors(input.data(), swizzled.data(), 0); + nvte_unswizzle_scaling_factors(swizzled.data(), output.data(), 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + input.to_cpu(); + output.to_cpu(); + if (rowwise) { + compareResults("roundtrip_rowwise", output.rowwise_cpu_scale_inv_ptr(), + input.rowwise_cpu_scale_inv_ptr(), scale_shape[0] * scale_shape[1]); + } else { + compareResults("roundtrip_columnwise", output.columnwise_cpu_scale_inv_ptr(), + input.columnwise_cpu_scale_inv_ptr(), scale_shape[0] * scale_shape[1]); + } +} + +class SwizzleUnswizzleRoundtripTestSuite : public ::testing::TestWithParam, std::pair, bool>> {}; + +TEST_P(SwizzleUnswizzleRoundtripTestSuite, TestSwizzleUnswizzleRoundtrip) { + using namespace transformer_engine; + using namespace test; + + const auto num_tiles = std::get<0>(GetParam()); + const auto scaling_mode = std::get<1>(GetParam()); + const auto transa = std::get<2>(GetParam()); + + performTestSwizzleUnswizzleRoundtrip(num_tiles.first, num_tiles.second, + scaling_mode.first, scaling_mode.second, + transa); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + SwizzleUnswizzleRoundtripTestSuite, + ::testing::Combine( + ::testing::ValuesIn(num_tiles), + ::testing::ValuesIn(scaling_mode), + ::testing::ValuesIn(transa) + ), + [](const testing::TestParamInfo& info) { + std::string name = "roundtrip_ntiles" + + std::to_string(std::get<0>(info.param).first) + "X" + + std::to_string(std::get<0>(info.param).second) + "smode" + + std::to_string(std::get<1>(info.param).first) + "X"+ + std::to_string(std::get<1>(info.param).second) + "trans" + + std::to_string(std::get<2>(info.param)); + return name; + }); diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h index 5e420b2d42..a26d378437 100644 --- a/transformer_engine/common/include/transformer_engine/swizzle.h +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -45,6 +45,34 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, const size_t num_tensors, cudaStream_t stream); +/*! \brief Unswizzling scaling factors from the interleaved layout used by GEMM back to row-major + * + * \param[in] input Input tensor with swizzled scale_inv. + * \param[in,out] output Output tensor which hosts non-swizzled scale_inv. + * \param[in] stream CUDA stream used for the operation. + * + * Requirements: + * - scale_inv is stored in row-major in output. + * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale. + * - data is quantitized along K-dimension, i.e. 1D-scaling block lies along the K-dimension. + */ +void nvte_unswizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream); + +/*! 
\brief Unswizzling scaling factors from the interleaved layout used by GEMM back to row-major + * + * \param[in] inputs Input tensors with swizzled scale_inv. + * \param[in,out] outputs Output tensors which hosts non-swizzled scale_inv. + * \param[in] num_tensors Number of input and output tensors. + * \param[in] stream CUDA stream used for the operation. + * + * Requirements: + * - scale_inv is stored in row-major in output. + * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale. + * - data is quantitized along K-dimension, i.e. 1D-scaling block lies along the K-dimension. + */ +void nvte_multi_tensor_unswizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, + const size_t num_tensors, cudaStream_t stream); + /*! \brief Swizzling FP8 block scaling scaling factors into mxfp8 interleaved layout for GEMM * * \param[in] input Input FP8 block-scaled tensor. diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 4425c4e9f7..716d40e816 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -54,6 +54,31 @@ __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) { for (int i = 0; i < kVectorSize; i++) regs[i] = new_regs[i]; } +template +__device__ inline void regs_unshuffle_with_bit_shifts(LType* regs_vec) { + // Inverse of regs_shuffle_with_bit_shifts + // inp, 4-byte chunks [0,4,8,12, 1,5,9,13, 2,6,10,14, 3,7,11,15] + // out, swapping byte to form new 4-byte chunks [0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15] + + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int kVectorSize = N_SF_PER_TD_PER_TILE * N_TILE_PER_TD; + int32_t new_regs[kVectorSize]; + int32_t* regs = reinterpret_cast(regs_vec); + +#pragma unroll + for (int i = 0; i < N_TILE_PER_TD; i++) { +#pragma unroll + for (int j = 0; j < N_SF_PER_TD_PER_TILE; j++) { + new_regs[i + j * N_TILE_PER_TD] = + ((regs[i * N_SF_PER_TD_PER_TILE + 0] >> 8 * j) & 0xFF) | + (((regs[i * N_SF_PER_TD_PER_TILE + 1] >> 8 * j) & 0xFF) << 8) | + (((regs[i * N_SF_PER_TD_PER_TILE + 2] >> 8 * j) & 0xFF) << 16) | + (((regs[i * N_SF_PER_TD_PER_TILE + 3] >> 8 * j) & 0xFF) << 24); + } + } +#pragma unroll + for (int i = 0; i < kVectorSize; i++) regs[i] = new_regs[i]; +} template __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, const int M, const int K, const int original_M, @@ -170,6 +195,23 @@ __device__ inline void regs_shuffle(LType* regs_vec) { for (int i = 0; i < kVectorSize; i++) ptr[i] = tmp[i]; } +// Inverse of regs_shuffle. 
+template +__device__ inline void regs_unshuffle(LType* regs_vec) { + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + if constexpr (N_TILE_PER_TD == 1) return; + + constexpr int kVectorSize = N_SF_PER_TD_PER_TILE * N_TILE_PER_TD; + int32_t tmp[kVectorSize]; + int32_t* ptr = reinterpret_cast(regs_vec); +#pragma unroll + for (int i = 0; i < kVectorSize; i++) + tmp[i % N_SF_PER_TD_PER_TILE * N_TILE_PER_TD + i / N_SF_PER_TD_PER_TILE] = ptr[i]; + +#pragma unroll + for (int i = 0; i < kVectorSize; i++) ptr[i] = tmp[i]; +} + template __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, const int M, const int K, const int original_M, @@ -239,6 +281,162 @@ __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, } } +template +__device__ void unswizzle_row_scaling_kernel_impl(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K, const int bid_x, + const int bid_y, const int grid_dim_x, + const int grid_dim_y) { + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD; + + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M; + + int n_tiles_in_tb = N_TILES_IN_TB; + const int K_i32 = K / 4; + if (bid_x == grid_dim_x - 1) { + n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1; + } + + const int input_offset = + bid_y * SF_TILE_DIM_M_I32 * K_i32 + bid_x * N_TILES_IN_TB * SF_TILE_SIZE_I32; + const int* input_i32 = reinterpret_cast(input) + input_offset; + const int output_offset = bid_y * SF_TILE_DIM_M_I32 * K_i32 + bid_x * N_TILES_IN_TB; + uint8_t* output_u8 = reinterpret_cast(output); + + extern __shared__ int4 slm_v4i[]; + + int linear_id = threadIdx.y * blockDim.x + threadIdx.x; + const int4* input_v4i = reinterpret_cast(input_i32); +#pragma unroll + for (int i = linear_id; i < SF_TILE_SIZE_I32 * n_tiles_in_tb / 4; i += blockDim.x * blockDim.y) { + slm_v4i[i] = input_v4i[i]; + } + __syncthreads(); + + LType regs_vec[N_SF_PER_TD_PER_TILE]; + if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) { +#pragma unroll + for (int i = 0; i < N_TILE_PER_TD; i++) { + reinterpret_cast(regs_vec)[i] = + slm_v4i[(threadIdx.x * N_TILE_PER_TD + i) * SF_TILE_SIZE_I32 / 4 + threadIdx.y]; + } + + regs_unshuffle(regs_vec); + +#pragma unroll + for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + const int thread_offset = (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD; + for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { + const int index = (output_offset + thread_offset) * sizeof(int) + j; + if (index / K < original_M && index % K < original_K) { + output_u8[index / K * original_K + index % K] = + reinterpret_cast(regs_vec + i)[j]; + } + } + } + } +} + +template +__device__ void unswizzle_col_scaling_kernel_impl(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K, const int bid_x, + const int bid_y, const int grid_dim_x, + const int grid_dim_y) { + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE; + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + + constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M / 4; + constexpr int SF_TILE_DIM_K_I32 = SF_TILE_DIM_K; + + const int M_i32 = M / 4; + const int K_i32 = K; + + int m_tiles_in_tb = N_TILE_PER_TD; + int k_tiles_in_tb = TB_DIM; + if (bid_x == grid_dim_x - 1) { + k_tiles_in_tb = 
(K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1; + } + if (bid_y == grid_dim_y - 1) { + m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1; + } + + const int32_t* input_i32[N_TILE_PER_TD]; +#pragma unroll + for (int i = 0; i < m_tiles_in_tb; i++) { + input_i32[i] = reinterpret_cast(input) + bid_x * TB_DIM * SF_TILE_SIZE_I32 + + (bid_y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32; + } + const int output_offset = + bid_x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + bid_y * N_TILE_PER_TD * SF_TILE_DIM_M_I32; + uint8_t* output_u8 = reinterpret_cast(output); + + extern __shared__ int slm[]; + + int linear_id = threadIdx.y * blockDim.x + threadIdx.x; +#pragma unroll + for (int i = 0; i < m_tiles_in_tb; i++) { + __align__(16) int4* input_v4i = reinterpret_cast(const_cast(input_i32[i])); + __align__(16) int4* slm_v4i = + reinterpret_cast(slm + i * k_tiles_in_tb * SF_TILE_SIZE_I32); +#pragma unroll + for (int j = linear_id; j < SF_TILE_SIZE_I32 * k_tiles_in_tb / 4; + j += blockDim.x * blockDim.y) { + slm_v4i[j] = input_v4i[j]; + } + } + __syncthreads(); + + LType regs_vec[N_SF_PER_TD_PER_TILE]; + if (threadIdx.x * N_TILE_PER_TD < m_tiles_in_tb * SF_TILE_DIM_M_I32 && + threadIdx.y < k_tiles_in_tb) { + int tM = threadIdx.x * N_SF_PER_TD; + int* slm_tile = slm + (threadIdx.y * SF_TILE_SIZE_I32 + + tM / SF_TILE_DIM_M * k_tiles_in_tb * SF_TILE_SIZE_I32); +#pragma unroll + for (int i = 0; i < N_SF_PER_TD; i++) { + reinterpret_cast(regs_vec)[i] = + slm_tile[(tM % SF_TILE_DIM_M) / NEW_SF_TILE_DIM_M_I32 + + ((tM + i) % NEW_SF_TILE_DIM_M_I32) * NEW_SF_TILE_DIM_K_I32]; + } + + regs_unshuffle_with_bit_shifts(regs_vec); + +#pragma unroll + for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + const int thread_offset = + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD; + for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { + const int index = (output_offset + thread_offset) * sizeof(int) + j; + if (index / M < original_K && index % M < original_M) { + output_u8[index / M * original_M + index % M] = + reinterpret_cast(regs_vec + i)[j]; + } + } + } + } +} + +template +__global__ void __launch_bounds__(TB_DIM* TB_DIM) + unswizzle_scaling_kernel(const void* input, void* output, const int M, const int K, + const int original_M, const int original_K, const bool row_scaling) { + const int bid_x = blockIdx.x; + const int bid_y = blockIdx.y; + const int grid_dim_x = gridDim.x; + const int grid_dim_y = gridDim.y; + if (row_scaling) { + unswizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } else { + unswizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); + } +} + template __global__ void __launch_bounds__(TB_DIM* TB_DIM) swizzle_row_scaling_kernel(const void* input, void* output, const int M, const int K, @@ -268,6 +466,63 @@ struct MultiSwizzleArgs { int num_tensors; }; +template +__global__ void multi_tensor_unswizzle_row_scaling_kernel(MultiSwizzleArgs kernel_args) { + const int bid = blockIdx.x; + int tensor_id = 0; + while (kernel_args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + const void* input = kernel_args.input_list[tensor_id]; + void* output = kernel_args.output_list[tensor_id]; + const int M = kernel_args.m_list[tensor_id]; + const int K = kernel_args.k_list[tensor_id]; + const int original_M = kernel_args.original_m_list[tensor_id]; + const int original_K = kernel_args.original_k_list[tensor_id]; + + constexpr int 
N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD; + + const int num_tiles_k = K / SF_TILE_DIM_K; + const int num_tiles_m = M / SF_TILE_DIM_M; + const int flat_offset = bid - kernel_args.block_range[tensor_id]; + const int grid_dim_x = DIVUP(num_tiles_k, N_TILES_IN_TB); + const int grid_dim_y = num_tiles_m; + const int bid_x = flat_offset / grid_dim_y; + const int bid_y = flat_offset % grid_dim_y; + + unswizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); +} + +template +__global__ void multi_tensor_unswizzle_col_scaling_kernel(MultiSwizzleArgs kernel_args) { + const int bid = blockIdx.x; + int tensor_id = 0; + while (kernel_args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + const void* input = kernel_args.input_list[tensor_id]; + void* output = kernel_args.output_list[tensor_id]; + const int M = kernel_args.m_list[tensor_id]; + const int K = kernel_args.k_list[tensor_id]; + const int original_M = kernel_args.original_m_list[tensor_id]; + const int original_K = kernel_args.original_k_list[tensor_id]; + + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + + const int num_tiles_k = K / SF_TILE_DIM_K; + const int num_tiles_m = M / SF_TILE_DIM_M; + const int flat_offset = bid - kernel_args.block_range[tensor_id]; + const int grid_dim_x = DIVUP(num_tiles_k, TB_DIM); + const int grid_dim_y = DIVUP(num_tiles_m, N_TILE_PER_TD); + const int bid_x = flat_offset / grid_dim_y; + const int bid_y = flat_offset % grid_dim_y; + + unswizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); +} + template __global__ void multi_tensor_swizzle_row_scaling_kernel(MultiSwizzleArgs kernel_args) { // Find tensor corresponding to block @@ -647,6 +902,89 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, NVTE_CHECK_CUDA(cudaGetLastError()); } +template +void launch_multi_tensor_unswizzle_scaling_factors(MultiSwizzleArgs& kernel_args, + const int vec_load_size, const bool is_rowwise, + cudaStream_t stream) { + int n_tiles_in_tb = TB_DIM * vec_load_size; + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + for (size_t j = 0; j < kernel_args.num_tensors; j++) { + const int m = kernel_args.m_list[j]; + const int k = kernel_args.k_list[j]; + int num_tiles_m = m / SF_TILE_DIM_M; + int num_tiles_k = k / SF_TILE_DIM_K; + if (is_rowwise) { + kernel_args.block_range[j + 1] = + kernel_args.block_range[j] + DIVUP(num_tiles_k, n_tiles_in_tb) * num_tiles_m; + } else { + kernel_args.block_range[j + 1] = + kernel_args.block_range[j] + + DIVUP(num_tiles_k, TB_DIM) * DIVUP(num_tiles_m, vec_load_size); + } + } + + int num_blocks = kernel_args.block_range[kernel_args.num_tensors]; + if (num_blocks > 0) { + dim3 block_size(TB_DIM, TB_DIM); + if (is_rowwise) { + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + multi_tensor_unswizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + multi_tensor_unswizzle_row_scaling_kernel + <<>>(kernel_args); + break; + case 2: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + multi_tensor_unswizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + multi_tensor_unswizzle_row_scaling_kernel + <<>>(kernel_args); + break; + case 1: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + multi_tensor_unswizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, 
slm_size)); + multi_tensor_unswizzle_row_scaling_kernel + <<>>(kernel_args); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } else { + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + multi_tensor_unswizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + multi_tensor_unswizzle_col_scaling_kernel + <<>>(kernel_args); + break; + case 2: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + multi_tensor_unswizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + multi_tensor_unswizzle_col_scaling_kernel + <<>>(kernel_args); + break; + case 1: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + multi_tensor_unswizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + multi_tensor_unswizzle_col_scaling_kernel + <<>>(kernel_args); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } + NVTE_CHECK_CUDA(cudaGetLastError()); + } +} + void multi_tensor_swizzle_scaling_factors(const std::vector& input, std::vector& output, cudaStream_t stream) { auto num_tensors = input.size(); @@ -816,28 +1154,436 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, kernel_args, vec_load_size, false, stream); } } -} // namespace transformer_engine -/* - * WIP (Phuong): - * - Opt for bank conflicts - * - Adding swizzle for 2d-block scaling. - */ -void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_swizzle_scaling_factors); - using namespace transformer_engine; - swizzle_scaling_factors(convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream); -} +void unswizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { + const auto& scaling_mode = input->scaling_mode; + NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING, + "Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ")."); -void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, - const size_t num_tensors, cudaStream_t stream) { - NVTE_API_CALL(nvte_multi_tensor_swizzle_scaling_factors); - using namespace transformer_engine; - NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0."); - std::vector input_list, output_list; - for (size_t i = 0; i < num_tensors; i++) { - input_list.push_back(convertNVTETensorCheck(inputs[i])); - output_list.push_back(convertNVTETensorCheck(outputs[i])); + CheckInputTensor(*input, "scaling_factor_input"); + CheckInputTensor(*output, "scaling_factor_output"); + NVTE_CHECK(input->with_gemm_swizzled_scales, "Expected input tensor with swizzled scales."); + NVTE_CHECK(!output->with_gemm_swizzled_scales, + "Expected output tensor in row-major compact format."); + + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: + NVTE_CHECK(is_fp8_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP8, got ", + to_string(input->dtype()), ")."); + break; + case NVTE_NVFP4_1D_SCALING: + NVTE_CHECK(is_fp4_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP4, got ", + to_string(input->dtype()), ")."); + break; + default: + NVTE_ERROR("Invalid scaling mode"); } - multi_tensor_swizzle_scaling_factors(input_list, output_list, stream); + + const bool has_rowwise_scale_inv = input->scale_inv.has_data(); + const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data(); + NVTE_CHECK(!has_rowwise_scale_inv || 
!has_columnwise_scale_inv, + "Input tensor has both row-wise and column-wise scaling factors"); + if (!has_rowwise_scale_inv && !has_columnwise_scale_inv) { + return; + } + + int m{0}, k{0}; + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: { + if (has_rowwise_scale_inv) { + NVTE_CHECK(input->scale_inv.shape.size() == 2, + "Expected 2D scaling factors, got shape=", input->scale_inv.shape, "."); + m = input->scale_inv.shape[0]; + k = input->scale_inv.shape[1]; + } else if (has_columnwise_scale_inv) { + NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2, + "Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape, + "."); + m = input->columnwise_scale_inv.shape[1]; + k = input->columnwise_scale_inv.shape[0]; + } + break; + } + case NVTE_NVFP4_1D_SCALING: { + if (has_rowwise_scale_inv) { + NVTE_CHECK(input->scale_inv.shape.size() == 2, + "Expected 2D scaling factors, got shape=", input->scale_inv.shape, "."); + m = input->scale_inv.shape[0]; + k = input->scale_inv.shape[1]; + } else if (has_columnwise_scale_inv) { + NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2, + "Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape, + "."); + m = input->columnwise_scale_inv.shape[0]; + k = input->columnwise_scale_inv.shape[1]; + } + break; + } + default: + NVTE_ERROR("Invalid scaling mode"); + } + + constexpr int SF_TILE_DIM_M = 128; + constexpr int SF_TILE_DIM_K = 4; + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + + if (has_rowwise_scale_inv) { + NVTE_CHECK(output->scale_inv.has_data(), + "Output tensor does not have row-wise scaling factors."); + } + if (has_columnwise_scale_inv) { + NVTE_CHECK(output->columnwise_scale_inv.has_data(), + "Output tensor does not have column-wise scaling factors."); + } + + bool rowwise_unswizzle{false}, columnwise_unswizzle{false}; + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: { + rowwise_unswizzle = has_rowwise_scale_inv; + columnwise_unswizzle = has_columnwise_scale_inv; + break; + } + case NVTE_NVFP4_1D_SCALING: { + rowwise_unswizzle = true; + columnwise_unswizzle = false; + break; + } + default: + NVTE_ERROR("Invalid scaling mode"); + } + + const dim3 block_size(TB_DIM, TB_DIM); + const int num_tiles_m = m / SF_TILE_DIM_M; + const int num_tiles_k = k / SF_TILE_DIM_K; + + if (rowwise_unswizzle) { + int vec_load_size = (num_tiles_k - 1) % 4 + 1; + if (vec_load_size == 3) vec_load_size = 1; + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + + int original_M{0}, original_K{0}; + void *input_scale_inv_ptr{nullptr}, *output_scale_inv_ptr{nullptr}; + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: { + original_M = output->flat_first_dim(); + original_K = output->flat_last_dim() / MXFP8_BLOCK_SIZE; + input_scale_inv_ptr = input->scale_inv.dptr; + output_scale_inv_ptr = output->scale_inv.dptr; + NVTE_CHECK(static_cast(original_M) * original_K == output->scale_inv.numel(), + "Expected output tensor to have ", static_cast(original_M) * original_K, + " row-wise scaling factors, but got shape=", output->scale_inv.shape, "."); + break; + } + case NVTE_NVFP4_1D_SCALING: { + if (has_rowwise_scale_inv) { + original_M = output->flat_first_dim(); + original_K = output->flat_last_dim() / NVFP4_BLOCK_SIZE; + input_scale_inv_ptr = input->scale_inv.dptr; + 
output_scale_inv_ptr = output->scale_inv.dptr; + NVTE_CHECK(static_cast(original_M) * original_K == output->scale_inv.numel(), + "Expected output tensor to have ", + static_cast(original_M) * original_K, + " row-wise scaling factors, but got shape=", output->scale_inv.shape, "."); + } else if (has_columnwise_scale_inv) { + original_M = output->flat_last_dim(); + original_K = output->flat_first_dim() / NVFP4_BLOCK_SIZE; + input_scale_inv_ptr = input->columnwise_scale_inv.dptr; + output_scale_inv_ptr = output->columnwise_scale_inv.dptr; + NVTE_CHECK( + static_cast(original_M) * original_K == output->columnwise_scale_inv.numel(), + "Expected output tensor to have ", static_cast(original_M) * original_K, + " column-wise scaling factors, but got shape=", output->columnwise_scale_inv.shape, + "."); + } + break; + } + default: + NVTE_ERROR("Invalid scaling mode"); + } + + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(unswizzle_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + unswizzle_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K, true); + break; + case 2: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(unswizzle_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + unswizzle_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K, true); + break; + case 1: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(unswizzle_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + unswizzle_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K, true); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + NVTE_CHECK_CUDA(cudaGetLastError()); + } + + if (columnwise_unswizzle) { + int vec_load_size = (num_tiles_m - 1) % 4 + 1; + if (vec_load_size == 3) vec_load_size = 1; + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + const int original_M = output->flat_last_dim(); + const int original_K = output->flat_first_dim() / MXFP8_BLOCK_SIZE; + NVTE_CHECK(static_cast(original_M) * original_K == output->columnwise_scale_inv.numel(), + "Expected output tensor to have ", static_cast(original_M) * original_K, + " column-wise scaling factors, but got shape=", output->columnwise_scale_inv.shape, + "."); + + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(unswizzle_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + unswizzle_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, k, + original_M, original_K, false); + break; + case 2: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(unswizzle_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + unswizzle_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, k, + original_M, original_K, false); + break; + case 1: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(unswizzle_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + unswizzle_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, k, + original_M, original_K, false); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + NVTE_CHECK_CUDA(cudaGetLastError()); + } +} + +void 
multi_tensor_unswizzle_scaling_factors(const std::vector& input, + const std::vector& output, + cudaStream_t stream) { + size_t num_tensors = input.size(); + const auto& first_scaling_mode = input[0]->scaling_mode; + + constexpr int SF_TILE_DIM_M = 128; + constexpr int SF_TILE_DIM_K = 4; + + bool all_has_data = true; + bool all_has_columnwise_data = true; + bool all_nvfp4 = true; + for (size_t i = 0; i < num_tensors; i++) { + const auto scaling_mode = input[i]->scaling_mode; + const auto is_fp8 = is_fp8_dtype(input[i]->dtype()); + const auto is_fp4 = is_fp4_dtype(input[i]->dtype()); + + NVTE_CHECK(scaling_mode == first_scaling_mode, + "All tensors should have the same scaling mode in multi-tensor unswizzle."); + NVTE_CHECK( + (is_fp8 && is_mxfp8_scaling(scaling_mode)) || (is_fp4 && is_nvfp4_scaling(scaling_mode)), + "Not implemented scaling mode " + to_string(scaling_mode) + "."); + NVTE_CHECK(input[i]->with_gemm_swizzled_scales, + "Expected input tensors with scales in GEMM swizzled format."); + NVTE_CHECK(!output[i]->with_gemm_swizzled_scales, + "Expected output tensors with scales in compact format."); + NVTE_CHECK(input[i]->numel() != 0, "Tensor input[", i, "] is empty."); + CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]"); + CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]"); + + all_has_data = all_has_data && input[i]->scale_inv.has_data(); + all_has_columnwise_data = + (all_has_columnwise_data && input[i]->columnwise_scale_inv.has_data()); + all_nvfp4 = all_nvfp4 && is_nvfp4_scaling(scaling_mode); + } + NVTE_CHECK(all_has_data || all_has_columnwise_data, + "All tensors should have data or columnwise data."); + NVTE_CHECK(!all_has_data || !all_has_columnwise_data, + "All tensors have both data and columnwise data."); + + const bool rowwise_swizzle = all_has_data || all_nvfp4; + const bool columnwise_swizzle = all_has_columnwise_data && !all_nvfp4; + + if (rowwise_swizzle) { + MultiSwizzleArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.block_range[0] = 0; + int vec_load_size = 4; + for (size_t i = 0; i < num_tensors; i++) { + if (kernel_args.num_tensors == kMaxTensorsPerKernel) { + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_unswizzle_scaling_factors( + kernel_args, vec_load_size, true, stream); + kernel_args.num_tensors = 0; + vec_load_size = 4; + } + int m, k; + if (all_has_data) { + m = input[i]->scale_inv.shape[0]; + k = input[i]->scale_inv.shape[1]; + } else { + NVTE_CHECK(all_nvfp4, + "When doing rowwise unswizzle with columnwise data, it has to be NVFP4"); + m = input[i]->columnwise_scale_inv.shape[0]; + k = input[i]->columnwise_scale_inv.shape[1]; + } + + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + + if (all_has_data) { + NVTE_CHECK(output[i]->scale_inv.has_data(), "Output tensor ", i, + " does not have row-wise scaling factors."); + NVTE_CHECK(m * k == output[i]->scale_inv.numel(), "Expected output tensor ", i, " to have ", + m * k, " row-wise scaling factors, but got shape=", output[i]->scale_inv.shape, + "."); + } + if (all_has_columnwise_data) { + NVTE_CHECK(output[i]->columnwise_scale_inv.has_data(), "Output tensor ", i, + " does not have column-wise scaling factors."); + NVTE_CHECK(m * k == output[i]->columnwise_scale_inv.numel(), "Expected output tensor ", i, + " to have ", m * k, " column-wise scaling 
factors, but got shape=", + output[i]->columnwise_scale_inv.shape, "."); + } + + int num_tiles_k = k / SF_TILE_DIM_K; + int vec_load_size_i = (num_tiles_k - 1) % 4 + 1; + vec_load_size = all_nvfp4 ? 1 : std::min(vec_load_size, vec_load_size_i); + + const int pos = kernel_args.num_tensors; + kernel_args.m_list[pos] = m; + kernel_args.k_list[pos] = k; + if (!all_nvfp4 || all_has_data) { + int block_scale_size = all_nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE; + kernel_args.input_list[pos] = const_cast(input[i]->scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->scale_inv.dptr; + kernel_args.original_m_list[pos] = output[i]->flat_first_dim(); + kernel_args.original_k_list[pos] = output[i]->flat_last_dim() / block_scale_size; + } else { + kernel_args.input_list[pos] = const_cast(input[i]->columnwise_scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr; + kernel_args.original_m_list[pos] = output[i]->flat_last_dim(); + kernel_args.original_k_list[pos] = output[i]->flat_first_dim() / NVFP4_BLOCK_SIZE; + } + kernel_args.num_tensors++; + } + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_unswizzle_scaling_factors( + kernel_args, vec_load_size, true, stream); + } + + if (columnwise_swizzle) { + NVTE_CHECK(!all_nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle"); + + MultiSwizzleArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.block_range[0] = 0; + int vec_load_size = 4; + for (size_t i = 0; i < num_tensors; i++) { + if (kernel_args.num_tensors == kMaxTensorsPerKernel) { + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_unswizzle_scaling_factors( + kernel_args, vec_load_size, false, stream); + kernel_args.num_tensors = 0; + vec_load_size = 4; + } + const int m = input[i]->columnwise_scale_inv.shape[1]; + const int k = input[i]->columnwise_scale_inv.shape[0]; + + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(), + output[i]->columnwise_scale_inv.shape.end(), 1, + std::multiplies()), + "Input.columnwise_scale_inv size is not equal to " + "Output.columnwise_scale_inv size!"); + + int num_tiles_k = k / SF_TILE_DIM_K; + int vec_load_size_i = (num_tiles_k - 1) % 4 + 1; + vec_load_size = std::min(vec_load_size, vec_load_size_i); + + const int pos = kernel_args.num_tensors; + kernel_args.input_list[pos] = const_cast(input[i]->columnwise_scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr; + kernel_args.m_list[pos] = m; + kernel_args.k_list[pos] = k; + kernel_args.original_m_list[pos] = output[i]->flat_last_dim(); + kernel_args.original_k_list[pos] = output[i]->flat_first_dim() / MXFP8_BLOCK_SIZE; + kernel_args.num_tensors++; + } + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_unswizzle_scaling_factors( + kernel_args, vec_load_size, false, stream); + } +} +} // namespace transformer_engine + +/* + * WIP (Phuong): + * - Opt for bank conflicts + * - Adding swizzle for 2d-block scaling. 
+ */
+void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_swizzle_scaling_factors);
+  using namespace transformer_engine;
+  swizzle_scaling_factors(convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
+}
+
+void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
+                                               const size_t num_tensors, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_multi_tensor_swizzle_scaling_factors);
+  using namespace transformer_engine;
+  NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0.");
+  std::vector<Tensor*> input_list, output_list;
+  for (size_t i = 0; i < num_tensors; i++) {
+    input_list.push_back(convertNVTETensorCheck(inputs[i]));
+    output_list.push_back(convertNVTETensorCheck(outputs[i]));
+  }
+  multi_tensor_swizzle_scaling_factors(input_list, output_list, stream);
+}
+
+void nvte_unswizzle_scaling_factors(const NVTETensor input, NVTETensor output,
+                                    cudaStream_t stream) {
+  NVTE_API_CALL(nvte_unswizzle_scaling_factors);
+  using namespace transformer_engine;
+  unswizzle_scaling_factors(convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
+}
+
+void nvte_multi_tensor_unswizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
+                                                 const size_t num_tensors, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_multi_tensor_unswizzle_scaling_factors);
+  using namespace transformer_engine;
+  NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0.");
+  std::vector<Tensor*> input_list, output_list;
+  for (size_t i = 0; i < num_tensors; i++) {
+    input_list.push_back(convertNVTETensorCheck(inputs[i]));
+    output_list.push_back(convertNVTETensorCheck(outputs[i]));
+  }
+  multi_tensor_unswizzle_scaling_factors(input_list, output_list, stream);
+}
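
Note: the round-trip behavior covered by SwizzleUnswizzleRoundtripTestSuite condenses to the call pattern below. This is a sketch only: it reuses the test-side helpers (test::Tensor, fillUniform, compareResults) from tests/cpp rather than the public NVTETensor creation API, and the 256x256 shape is simply an arbitrary size that satisfies the padding requirements (M and K multiples of 128, so K/32 is a multiple of 4).

// Sketch (not part of the patch): rowwise-only MXFP8 swizzle/unswizzle round trip,
// mirroring performTestSwizzleUnswizzleRoundtrip above.
void roundtrip_sketch() {
  using namespace test;
  const std::vector<size_t> data_shape{256, 256};  // M, K
  const DType dtype = DType::kFloat8E4M3;

  Tensor input("input", data_shape, dtype, /*rowwise=*/true, /*columnwise=*/false,
               NVTE_MXFP8_1D_SCALING);
  Tensor swizzled("swizzled", data_shape, dtype, true, false, NVTE_MXFP8_1D_SCALING);
  swizzled.set_with_gemm_swizzled_scales(true);  // intermediate holds GEMM-layout scales
  Tensor output("output", data_shape, dtype, true, false, NVTE_MXFP8_1D_SCALING);

  fillUniform(&input);

  nvte_swizzle_scaling_factors(input.data(), swizzled.data(), /*stream=*/0);
  nvte_unswizzle_scaling_factors(swizzled.data(), output.data(), /*stream=*/0);
  cudaDeviceSynchronize();

  // 256 rows x (256 / 32) scale columns must come back bit-identical.
  input.to_cpu();
  output.to_cpu();
  compareResults("roundtrip_rowwise", output.rowwise_cpu_scale_inv_ptr(),
                 input.rowwise_cpu_scale_inv_ptr(), 256 * (256 / 32));
}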
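
The index arithmetic in compute_ref_unswizzle can also be checked on the host in isolation. The sketch below derives the forward row-major-to-swizzled mapping from the in_index formula in that helper, scatters a buffer through it, and asserts that compute_ref_unswizzle<128, 4, true> restores row-major order. It assumes it lives in the same translation unit as the helper; the sizes and fill pattern are arbitrary.

#include <cassert>
#include <cstdint>
#include <vector>

// Sketch (not part of the patch): host-side consistency check for the row-wise reference.
void check_rowwise_unswizzle_ref(const size_t M, const size_t K) {  // M % 128 == 0, K % 4 == 0
  std::vector<uint8_t> row_major(M * K), swizzled(M * K), recovered(M * K);
  for (size_t i = 0; i < M * K; ++i) row_major[i] = static_cast<uint8_t>(i * 37 + 11);

  for (size_t m = 0; m < M; ++m) {
    for (size_t k = 0; k < K; ++k) {
      // Forward direction of the same tile arithmetic used by compute_ref_unswizzle:
      // a 128x4 scale tile is stored as 32 rows of 16 bytes (4 row-groups interleaved).
      const size_t tile_id_m = m / 128, tile_id_k = k / 4;
      const size_t m_in_tile = m % 128, k_in_tile = k % 4;
      const size_t row_in_new_tile = m_in_tile % 32;
      const size_t col_in_new_tile = m_in_tile / 32 * 4 + k_in_tile;
      const size_t tile_ptr = tile_id_m * 128 * K + tile_id_k * 128 * 4;
      swizzled[tile_ptr + row_in_new_tile * 16 + col_in_new_tile] = row_major[m * K + k];
    }
  }

  compute_ref_unswizzle<128, 4, true>(swizzled.data(), recovered.data(), M, K);
  assert(recovered == row_major);  // unswizzling the swizzled layout is the identity
}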
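
On the launch side, both the single-tensor and multi-tensor paths derive their vectorized load width from a tile count via (num_tiles - 1) % 4 + 1 with a fallback from 3 to 1 (row-wise uses the K-tile count, column-wise the M-tile count, and the multi-tensor launcher additionally takes the minimum across tensors). That expression selects the widest of 1, 2 or 4 int-sized loads that evenly divides the tile count. A minimal restatement with worked values, for illustration only:

// Sketch (not part of the patch): the vec_load_size selection used by the launchers above.
constexpr int pick_vec_load_size(int num_tiles) {
  int vec_load_size = (num_tiles - 1) % 4 + 1;
  if (vec_load_size == 3) vec_load_size = 1;  // a 3-wide vector load is not supported
  return vec_load_size;
}
static_assert(pick_vec_load_size(8) == 4, "multiple of 4 tiles -> 4-wide loads");
static_assert(pick_vec_load_size(6) == 2, "remainder 2 -> 2-wide loads");
static_assert(pick_vec_load_size(7) == 1, "remainder 3 falls back to scalar loads");
static_assert(pick_vec_load_size(5) == 1, "odd tile count -> scalar loads");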
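
For the multi-tensor kernels, block_range holds an exclusive prefix sum of per-tensor block counts, and each CUDA block recovers its tensor and 2D position by subtracting the tensor's starting block and unflattening with grid_dim_y as the fast dimension. The host sketch below walks that mapping; the concrete numbers (6, 2 and 4 blocks, grid_dim_y of 3, 2 and 4) are hypothetical and chosen only so each count factors as grid_dim_x * grid_dim_y.

#include <array>
#include <cstdio>

// Sketch (not part of the patch): how a flat block id maps back to (tensor_id, bid_x, bid_y).
int main() {
  const std::array<int, 4> block_range{0, 6, 8, 12};  // exclusive prefix sum over 3 tensors
  const std::array<int, 3> grid_dim_y{3, 2, 4};       // e.g. num_tiles_m per tensor (row-wise)

  for (int bid = 0; bid < block_range.back(); ++bid) {
    int tensor_id = 0;
    while (block_range[tensor_id + 1] <= bid) ++tensor_id;  // same linear search as the kernels
    const int flat_offset = bid - block_range[tensor_id];
    const int bid_x = flat_offset / grid_dim_y[tensor_id];
    const int bid_y = flat_offset % grid_dim_y[tensor_id];
    std::printf("bid=%2d -> tensor %d, bid_x=%d, bid_y=%d\n", bid, tensor_id, bid_x, bid_y);
  }
  return 0;
}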