Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 199 additions & 0 deletions tests/cpp/operator/test_swizzle.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,35 @@ void compute_ref_swizzle(const uint8_t *h_input, uint8_t *h_output,
}
}

// CPU reference for the unswizzle operation: for every (m, k) scale-factor
// position in the output, locate the corresponding byte inside the
// GEMM-swizzled input layout and copy it out.
//
// Template parameters give the swizzled scale-tile shape; row_scaling selects
// whether the output is written row-major (k fastest) or column-major.
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K, bool row_scaling>
void compute_ref_unswizzle(const uint8_t *h_input, uint8_t *h_output,
                           const size_t M, const size_t K) {
  // The swizzled layout folds each SF_TILE_DIM_M x SF_TILE_DIM_K tile into a
  // (SF_TILE_DIM_M / 4) x (SF_TILE_DIM_K * 4) tile with the same element count.
  constexpr int FOLDED_ROWS = SF_TILE_DIM_M / 4;
  constexpr int FOLDED_COLS = SF_TILE_DIM_K * 4;
  constexpr int TILE_ELEMS = SF_TILE_DIM_M * SF_TILE_DIM_K;

  for (size_t m = 0; m < M; ++m) {
    const size_t tile_m = m / SF_TILE_DIM_M;
    const int r = static_cast<int>(m % SF_TILE_DIM_M);
    for (size_t k = 0; k < K; ++k) {
      const size_t tile_k = k / SF_TILE_DIM_K;
      const int c = static_cast<int>(k % SF_TILE_DIM_K);

      // Position of element (r, c) within the folded tile.
      const int folded_row = r % FOLDED_ROWS;
      const int folded_col = (r / FOLDED_ROWS) * SF_TILE_DIM_K + c;

      // Start of this tile in the swizzled buffer, then the element offset.
      const size_t tile_base = tile_m * SF_TILE_DIM_M * K + tile_k * TILE_ELEMS;
      const size_t src = tile_base + folded_row * FOLDED_COLS + folded_col;

      if constexpr (row_scaling) {
        h_output[m * K + k] = h_input[src];
      } else {
        h_output[k * M + m] = h_input[src];
      }
    }
  }
}

void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) {
using namespace test;

Expand Down Expand Up @@ -110,6 +139,59 @@ void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool row
}
}

// Runs the 1D unswizzle test: fills a tensor with GEMM-swizzled scale
// factors, unswizzles them on the GPU via nvte_unswizzle_scaling_factors,
// and compares the result against the CPU reference compute_ref_unswizzle.
void performTestUnswizzle1D(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) {
  using namespace test;

  // Exactly one of rowwise/columnwise must be set. Skip unsupported
  // combinations BEFORE computing SF_MODE_X/SF_MODE_Y, so the scaling-mode
  // values are never read uninitialized (previously undefined behaviour when
  // neither flag was set).
  if ((rowwise && columnwise) || !(rowwise || columnwise)) {
    GTEST_SKIP() << "TEST SKIPPED, The scaling mode (rowwise=" + std::to_string(rowwise) +
                        ", columnwise=" + std::to_string(columnwise) + ") is not implemented.";
  }

  // Scaling-factor block shape: 1x32 for row-wise scaling, 32x1 for column-wise.
  const int SF_MODE_X = rowwise ? 1 : 32;
  const int SF_MODE_Y = rowwise ? 32 : 1;

  DType dtype = DType::kFloat8E4M3;

  const size_t M = num_tiles_M * MAT_TILE_DIM_M;
  const size_t K = num_tiles_K * MAT_TILE_DIM_K;
  const auto data_shape = transa ? std::vector<size_t>{M, K} : std::vector<size_t>{K, M};

  // One scale factor per SF_MODE_X x SF_MODE_Y block of data.
  const auto scale_shape = std::vector<size_t>{data_shape[0] / SF_MODE_X, data_shape[1] / SF_MODE_Y};

  Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);
  input.set_with_gemm_swizzled_scales(true);  // input holds scales in the swizzled layout
  Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);

  fillUniform(&input);

  std::unique_ptr<uint8_t[]> ref_output = std::make_unique<uint8_t[]>(scale_shape[0] * scale_shape[1]);

  // Launch on the default stream.
  nvte_unswizzle_scaling_factors(input.data(), output.data(), 0);

  // CPU reference; scale tiles are 128x4 (column-wise scales use swapped dims).
  if (rowwise)
    compute_ref_unswizzle<128, 4, true>(input.rowwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[0], scale_shape[1]);
  else
    compute_ref_unswizzle<128, 4, false>(input.columnwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[1], scale_shape[0]);

  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  output.to_cpu();
  if (rowwise) {
    compareResults("output_unswizzle", output.rowwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[0] * scale_shape[1]);
  } else {
    compareResults("output_unswizzle", output.columnwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[0] * scale_shape[1]);
  }
}

// Parameter tuple: ((num_tiles_M, num_tiles_K), (rowwise, columnwise), transa).
class SwizzleTestSuite : public ::testing::TestWithParam<std::tuple<std::pair<int, int>, std::pair<bool, bool>, bool>> {};


Expand All @@ -126,6 +208,21 @@ TEST_P(SwizzleTestSuite, TestSwizzle) {
transa);
}

// Parameter tuple: ((num_tiles_M, num_tiles_K), (rowwise, columnwise), transa).
class UnswizzleTestSuite : public ::testing::TestWithParam<std::tuple<std::pair<int, int>, std::pair<bool, bool>, bool>> {};

// Parameterized entry point: unpacks ((num_tiles_M, num_tiles_K),
// (rowwise, columnwise), transa) and forwards to the unswizzle test driver.
TEST_P(UnswizzleTestSuite, TestUnswizzle) {
  using namespace transformer_engine;
  using namespace test;

  const auto& [num_tiles, scaling_mode, transa] = GetParam();

  performTestUnswizzle1D(num_tiles.first, num_tiles.second,
                         scaling_mode.first, scaling_mode.second, transa);
}

namespace {

std::vector<std::pair<int, int>> num_tiles = {
Expand Down Expand Up @@ -164,3 +261,105 @@ INSTANTIATE_TEST_SUITE_P(
std::to_string(std::get<2>(info.param));
return name;
});

// Instantiate the unswizzle tests over every combination of tile counts,
// scaling modes and transpose settings; the lambda derives a unique,
// human-readable test name from the parameter tuple.
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
UnswizzleTestSuite,
::testing::Combine(
::testing::ValuesIn(num_tiles),
::testing::ValuesIn(scaling_mode),
::testing::ValuesIn(transa)
),
[](const testing::TestParamInfo<UnswizzleTestSuite::ParamType>& info) {
std::string name = "ntiles" +
std::to_string(std::get<0>(info.param).first) + "X" +
std::to_string(std::get<0>(info.param).second) + "smode" +
std::to_string(std::get<1>(info.param).first) + "X"+
std::to_string(std::get<1>(info.param).second) + "trans" +
std::to_string(std::get<2>(info.param));
return name;
});

// Round-trip test: swizzle the scale factors on the GPU, unswizzle them
// back, and verify the result matches the original input scales exactly.
void performTestSwizzleUnswizzleRoundtrip(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) {
  using namespace test;

  // Exactly one of rowwise/columnwise must be set. Skip unsupported
  // combinations BEFORE computing SF_MODE_X/SF_MODE_Y, so the scaling-mode
  // values are never read uninitialized (previously undefined behaviour when
  // neither flag was set).
  if ((rowwise && columnwise) || !(rowwise || columnwise)) {
    GTEST_SKIP() << "TEST SKIPPED, The scaling mode (rowwise=" + std::to_string(rowwise) +
                        ", columnwise=" + std::to_string(columnwise) + ") is not implemented.";
  }

  // Scaling-factor block shape: 1x32 for row-wise scaling, 32x1 for column-wise.
  const int SF_MODE_X = rowwise ? 1 : 32;
  const int SF_MODE_Y = rowwise ? 32 : 1;

  DType dtype = DType::kFloat8E4M3;

  const size_t M = num_tiles_M * MAT_TILE_DIM_M;
  const size_t K = num_tiles_K * MAT_TILE_DIM_K;
  const auto data_shape = transa ? std::vector<size_t>{M, K} : std::vector<size_t>{K, M};

  // One scale factor per SF_MODE_X x SF_MODE_Y block of data.
  const auto scale_shape = std::vector<size_t>{data_shape[0] / SF_MODE_X, data_shape[1] / SF_MODE_Y};

  Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);
  Tensor swizzled("swizzled", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);
  swizzled.set_with_gemm_swizzled_scales(true);  // intermediate tensor holds swizzled scales
  Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);

  fillUniform(&input);

  // Forward then inverse transform on the default stream.
  nvte_swizzle_scaling_factors(input.data(), swizzled.data(), 0);
  nvte_unswizzle_scaling_factors(swizzled.data(), output.data(), 0);

  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  input.to_cpu();
  output.to_cpu();
  // The round trip must reproduce the original scale factors exactly.
  if (rowwise) {
    compareResults("roundtrip_rowwise", output.rowwise_cpu_scale_inv_ptr<uint8_t>(),
                   input.rowwise_cpu_scale_inv_ptr<uint8_t>(), scale_shape[0] * scale_shape[1]);
  } else {
    compareResults("roundtrip_columnwise", output.columnwise_cpu_scale_inv_ptr<uint8_t>(),
                   input.columnwise_cpu_scale_inv_ptr<uint8_t>(), scale_shape[0] * scale_shape[1]);
  }
}

// Parameter tuple: ((num_tiles_M, num_tiles_K), (rowwise, columnwise), transa).
class SwizzleUnswizzleRoundtripTestSuite : public ::testing::TestWithParam<std::tuple<std::pair<int, int>, std::pair<bool, bool>, bool>> {};

// Parameterized entry point: unpacks ((num_tiles_M, num_tiles_K),
// (rowwise, columnwise), transa) and forwards to the round-trip test driver.
TEST_P(SwizzleUnswizzleRoundtripTestSuite, TestSwizzleUnswizzleRoundtrip) {
  using namespace transformer_engine;
  using namespace test;

  const auto& [num_tiles, scaling_mode, transa] = GetParam();

  performTestSwizzleUnswizzleRoundtrip(num_tiles.first, num_tiles.second,
                                       scaling_mode.first, scaling_mode.second, transa);
}

// Instantiate the round-trip tests over every combination of tile counts,
// scaling modes and transpose settings; the lambda derives a unique,
// human-readable test name from the parameter tuple.
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
SwizzleUnswizzleRoundtripTestSuite,
::testing::Combine(
::testing::ValuesIn(num_tiles),
::testing::ValuesIn(scaling_mode),
::testing::ValuesIn(transa)
),
[](const testing::TestParamInfo<SwizzleUnswizzleRoundtripTestSuite::ParamType>& info) {
std::string name = "roundtrip_ntiles" +
std::to_string(std::get<0>(info.param).first) + "X" +
std::to_string(std::get<0>(info.param).second) + "smode" +
std::to_string(std::get<1>(info.param).first) + "X"+
std::to_string(std::get<1>(info.param).second) + "trans" +
std::to_string(std::get<2>(info.param));
return name;
});
28 changes: 28 additions & 0 deletions transformer_engine/common/include/transformer_engine/swizzle.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,34 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud
void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
const size_t num_tensors, cudaStream_t stream);

/*! \brief Unswizzling scaling factors from the interleaved layout used by GEMM back to row-major
 *
 *  \param[in]     input     Input tensor with swizzled scale_inv.
 *  \param[in,out] output    Output tensor which hosts non-swizzled scale_inv.
 *  \param[in]     stream    CUDA stream used for the operation.
 *
 * Requirements:
 * - scale_inv is stored in row-major in output.
 * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale.
 * - data is quantized along K-dimension, i.e. 1D-scaling block lies along the K-dimension.
 */
void nvte_unswizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream);

/*! \brief Unswizzling scaling factors of multiple tensors from the interleaved layout used by GEMM back to row-major
 *
 *  \param[in]     inputs       Input tensors with swizzled scale_inv.
 *  \param[in,out] outputs      Output tensors which host non-swizzled scale_inv.
 *  \param[in]     num_tensors  Number of input and output tensors.
 *  \param[in]     stream       CUDA stream used for the operation.
 *
 * Requirements:
 * - scale_inv is stored in row-major in each output.
 * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale.
 * - data is quantized along K-dimension, i.e. 1D-scaling block lies along the K-dimension.
 */
void nvte_multi_tensor_unswizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
                                                 const size_t num_tensors, cudaStream_t stream);

/*! \brief Swizzling FP8 block scaling scaling factors into mxfp8 interleaved layout for GEMM
*
* \param[in] input Input FP8 block-scaled tensor.
Expand Down
Loading