-
Notifications
You must be signed in to change notification settings - Fork 653
Feature/unswizzle #2732
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Feature/unswizzle #2732
Changes from all commits
26ff5be
6a064cf
7d1567e
fd3ff05
64b86d7
621bc16
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -56,6 +56,35 @@ void compute_ref_swizzle(const uint8_t *h_input, uint8_t *h_output, | |||||||||||||||
| } | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| template <int SF_TILE_DIM_M, int SF_TILE_DIM_K, bool row_scaling> | ||||||||||||||||
| void compute_ref_unswizzle(const uint8_t *h_input, uint8_t *h_output, | ||||||||||||||||
| const size_t M, const size_t K) { | ||||||||||||||||
|
|
||||||||||||||||
| constexpr int NEW_SF_TILE_DIM_M = SF_TILE_DIM_M / 4; | ||||||||||||||||
| constexpr int NEW_SF_TILE_DIM_K = SF_TILE_DIM_K * 4; | ||||||||||||||||
| constexpr int SF_TILE_SIZE = SF_TILE_DIM_M * SF_TILE_DIM_K; | ||||||||||||||||
|
|
||||||||||||||||
| for (int m = 0; m < M; m++) { | ||||||||||||||||
| for (int k = 0; k < K; k++) { | ||||||||||||||||
|
|
||||||||||||||||
| int tile_id_m = m / SF_TILE_DIM_M; | ||||||||||||||||
| int tile_id_k = k / SF_TILE_DIM_K; | ||||||||||||||||
| int m_in_tile = m % SF_TILE_DIM_M; | ||||||||||||||||
| int k_in_tile = k % SF_TILE_DIM_K; | ||||||||||||||||
|
|
||||||||||||||||
| int row_in_new_tile = m_in_tile % NEW_SF_TILE_DIM_M; | ||||||||||||||||
| int col_in_new_tile = m_in_tile / NEW_SF_TILE_DIM_M * SF_TILE_DIM_K + k_in_tile; | ||||||||||||||||
|
|
||||||||||||||||
| int tile_input_ptr = tile_id_m * SF_TILE_DIM_M * K + tile_id_k * SF_TILE_SIZE; | ||||||||||||||||
| int in_index = tile_input_ptr + row_in_new_tile * NEW_SF_TILE_DIM_K + col_in_new_tile; | ||||||||||||||||
| if constexpr(row_scaling) | ||||||||||||||||
| h_output[k + m * K] = h_input[in_index]; | ||||||||||||||||
| else | ||||||||||||||||
| h_output[k * M + m] = h_input[in_index]; | ||||||||||||||||
| } | ||||||||||||||||
| } | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) { | ||||||||||||||||
| using namespace test; | ||||||||||||||||
|
|
||||||||||||||||
|
|
@@ -110,6 +139,59 @@ void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool row | |||||||||||||||
| } | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| void performTestUnswizzle1D(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) { | ||||||||||||||||
| using namespace test; | ||||||||||||||||
|
|
||||||||||||||||
| int SF_MODE_X, SF_MODE_Y; | ||||||||||||||||
| if (rowwise) { | ||||||||||||||||
| SF_MODE_X = 1; | ||||||||||||||||
| SF_MODE_Y = 32; | ||||||||||||||||
| } | ||||||||||||||||
| if (columnwise) { | ||||||||||||||||
| SF_MODE_X = 32; | ||||||||||||||||
| SF_MODE_Y = 1; | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| if ((rowwise && columnwise) || !(rowwise || columnwise)){ | ||||||||||||||||
| GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" + | ||||||||||||||||
| std::to_string(SF_MODE_Y) + "is not implemented."; | ||||||||||||||||
| } | ||||||||||||||||
|
Comment on lines
+155
to
+158
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Uninitialized variables used in skip message When The same issue exists in
Suggested change
|
||||||||||||||||
|
|
||||||||||||||||
| DType dtype = DType::kFloat8E4M3; | ||||||||||||||||
|
|
||||||||||||||||
| const size_t M = num_tiles_M * MAT_TILE_DIM_M; | ||||||||||||||||
| const size_t K = num_tiles_K * MAT_TILE_DIM_K; | ||||||||||||||||
| const auto data_shape = transa ? std::vector<size_t>{M, K} : std::vector<size_t>{K, M}; | ||||||||||||||||
|
|
||||||||||||||||
| const auto scale_shape = std::vector<size_t>{data_shape[0] / SF_MODE_X, data_shape[1] /SF_MODE_Y}; | ||||||||||||||||
|
|
||||||||||||||||
| Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); | ||||||||||||||||
| input.set_with_gemm_swizzled_scales(true); | ||||||||||||||||
| Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); | ||||||||||||||||
|
|
||||||||||||||||
| fillUniform(&input); | ||||||||||||||||
|
|
||||||||||||||||
| std::unique_ptr<uint8_t[]> ref_output = std::make_unique<uint8_t[]>(scale_shape[0] * scale_shape[1]); | ||||||||||||||||
|
|
||||||||||||||||
| nvte_unswizzle_scaling_factors(input.data(), output.data(), 0); | ||||||||||||||||
|
|
||||||||||||||||
| if (rowwise) | ||||||||||||||||
| compute_ref_unswizzle<128, 4, true>(input.rowwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[0], scale_shape[1]); | ||||||||||||||||
| else | ||||||||||||||||
| compute_ref_unswizzle<128, 4, false>(input.columnwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[1], scale_shape[0]); | ||||||||||||||||
|
|
||||||||||||||||
| cudaDeviceSynchronize(); | ||||||||||||||||
| auto err = cudaGetLastError(); | ||||||||||||||||
| ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); | ||||||||||||||||
|
|
||||||||||||||||
| output.to_cpu(); | ||||||||||||||||
| if (rowwise) { | ||||||||||||||||
| compareResults("output_unswizzle", output.rowwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[0] * scale_shape[1]); | ||||||||||||||||
| } else { | ||||||||||||||||
| compareResults("output_unswizzle", output.columnwise_cpu_scale_inv_ptr<uint8_t>(), ref_output.get(), scale_shape[0] * scale_shape[1]); | ||||||||||||||||
| } | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| class SwizzleTestSuite : public ::testing::TestWithParam<std::tuple<std::pair<int, int>, std::pair<bool, bool>, bool>> {}; | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
|
|
@@ -126,6 +208,21 @@ TEST_P(SwizzleTestSuite, TestSwizzle) { | |||||||||||||||
| transa); | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| class UnswizzleTestSuite : public ::testing::TestWithParam<std::tuple<std::pair<int, int>, std::pair<bool, bool>, bool>> {}; | ||||||||||||||||
|
|
||||||||||||||||
| TEST_P(UnswizzleTestSuite, TestUnswizzle) { | ||||||||||||||||
| using namespace transformer_engine; | ||||||||||||||||
| using namespace test; | ||||||||||||||||
|
|
||||||||||||||||
| const auto num_tiles = std::get<0>(GetParam()); | ||||||||||||||||
| const auto scaling_mode = std::get<1>(GetParam()); | ||||||||||||||||
| const auto transa = std::get<2>(GetParam()); | ||||||||||||||||
|
|
||||||||||||||||
| performTestUnswizzle1D(num_tiles.first, num_tiles.second, | ||||||||||||||||
| scaling_mode.first, scaling_mode.second, | ||||||||||||||||
| transa); | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| namespace { | ||||||||||||||||
|
|
||||||||||||||||
| std::vector<std::pair<int, int>> num_tiles = { | ||||||||||||||||
|
|
@@ -164,3 +261,105 @@ INSTANTIATE_TEST_SUITE_P( | |||||||||||||||
| std::to_string(std::get<2>(info.param)); | ||||||||||||||||
| return name; | ||||||||||||||||
| }); | ||||||||||||||||
|
|
||||||||||||||||
| INSTANTIATE_TEST_SUITE_P( | ||||||||||||||||
| OperatorTest, | ||||||||||||||||
| UnswizzleTestSuite, | ||||||||||||||||
| ::testing::Combine( | ||||||||||||||||
| ::testing::ValuesIn(num_tiles), | ||||||||||||||||
| ::testing::ValuesIn(scaling_mode), | ||||||||||||||||
| ::testing::ValuesIn(transa) | ||||||||||||||||
| ), | ||||||||||||||||
| [](const testing::TestParamInfo<UnswizzleTestSuite::ParamType>& info) { | ||||||||||||||||
| std::string name = "ntiles" + | ||||||||||||||||
| std::to_string(std::get<0>(info.param).first) + "X" + | ||||||||||||||||
| std::to_string(std::get<0>(info.param).second) + "smode" + | ||||||||||||||||
| std::to_string(std::get<1>(info.param).first) + "X"+ | ||||||||||||||||
| std::to_string(std::get<1>(info.param).second) + "trans" + | ||||||||||||||||
| std::to_string(std::get<2>(info.param)); | ||||||||||||||||
| return name; | ||||||||||||||||
| }); | ||||||||||||||||
|
|
||||||||||||||||
| void performTestSwizzleUnswizzleRoundtrip(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) { | ||||||||||||||||
| using namespace test; | ||||||||||||||||
|
|
||||||||||||||||
| int SF_MODE_X, SF_MODE_Y; | ||||||||||||||||
| if (rowwise) { | ||||||||||||||||
| SF_MODE_X = 1; | ||||||||||||||||
| SF_MODE_Y = 32; | ||||||||||||||||
| } | ||||||||||||||||
| if (columnwise) { | ||||||||||||||||
| SF_MODE_X = 32; | ||||||||||||||||
| SF_MODE_Y = 1; | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| if ((rowwise && columnwise) || !(rowwise || columnwise)){ | ||||||||||||||||
| GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" + | ||||||||||||||||
| std::to_string(SF_MODE_Y) + "is not implemented."; | ||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing space in skip message (roundtrip test) Same missing space issue — produces
Suggested change
|
||||||||||||||||
| } | ||||||||||||||||
|
Comment on lines
+296
to
+299
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Uninitialized variables used in skip message (roundtrip test) Same undefined-behaviour issue as in
Suggested change
|
||||||||||||||||
|
|
||||||||||||||||
| DType dtype = DType::kFloat8E4M3; | ||||||||||||||||
|
|
||||||||||||||||
| const size_t M = num_tiles_M * MAT_TILE_DIM_M; | ||||||||||||||||
| const size_t K = num_tiles_K * MAT_TILE_DIM_K; | ||||||||||||||||
| const auto data_shape = transa ? std::vector<size_t>{M, K} : std::vector<size_t>{K, M}; | ||||||||||||||||
|
|
||||||||||||||||
| const auto scale_shape = std::vector<size_t>{data_shape[0] / SF_MODE_X, data_shape[1] /SF_MODE_Y}; | ||||||||||||||||
|
|
||||||||||||||||
| Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); | ||||||||||||||||
| Tensor swizzled("swizzled", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); | ||||||||||||||||
| swizzled.set_with_gemm_swizzled_scales(true); | ||||||||||||||||
| Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); | ||||||||||||||||
|
|
||||||||||||||||
| fillUniform(&input); | ||||||||||||||||
|
|
||||||||||||||||
| nvte_swizzle_scaling_factors(input.data(), swizzled.data(), 0); | ||||||||||||||||
| nvte_unswizzle_scaling_factors(swizzled.data(), output.data(), 0); | ||||||||||||||||
|
|
||||||||||||||||
| cudaDeviceSynchronize(); | ||||||||||||||||
| auto err = cudaGetLastError(); | ||||||||||||||||
| ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); | ||||||||||||||||
|
|
||||||||||||||||
| input.to_cpu(); | ||||||||||||||||
| output.to_cpu(); | ||||||||||||||||
| if (rowwise) { | ||||||||||||||||
| compareResults("roundtrip_rowwise", output.rowwise_cpu_scale_inv_ptr<uint8_t>(), | ||||||||||||||||
| input.rowwise_cpu_scale_inv_ptr<uint8_t>(), scale_shape[0] * scale_shape[1]); | ||||||||||||||||
| } else { | ||||||||||||||||
| compareResults("roundtrip_columnwise", output.columnwise_cpu_scale_inv_ptr<uint8_t>(), | ||||||||||||||||
| input.columnwise_cpu_scale_inv_ptr<uint8_t>(), scale_shape[0] * scale_shape[1]); | ||||||||||||||||
| } | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| class SwizzleUnswizzleRoundtripTestSuite : public ::testing::TestWithParam<std::tuple<std::pair<int, int>, std::pair<bool, bool>, bool>> {}; | ||||||||||||||||
|
|
||||||||||||||||
| TEST_P(SwizzleUnswizzleRoundtripTestSuite, TestSwizzleUnswizzleRoundtrip) { | ||||||||||||||||
| using namespace transformer_engine; | ||||||||||||||||
| using namespace test; | ||||||||||||||||
|
|
||||||||||||||||
| const auto num_tiles = std::get<0>(GetParam()); | ||||||||||||||||
| const auto scaling_mode = std::get<1>(GetParam()); | ||||||||||||||||
| const auto transa = std::get<2>(GetParam()); | ||||||||||||||||
|
|
||||||||||||||||
| performTestSwizzleUnswizzleRoundtrip(num_tiles.first, num_tiles.second, | ||||||||||||||||
| scaling_mode.first, scaling_mode.second, | ||||||||||||||||
| transa); | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| INSTANTIATE_TEST_SUITE_P( | ||||||||||||||||
| OperatorTest, | ||||||||||||||||
| SwizzleUnswizzleRoundtripTestSuite, | ||||||||||||||||
| ::testing::Combine( | ||||||||||||||||
| ::testing::ValuesIn(num_tiles), | ||||||||||||||||
| ::testing::ValuesIn(scaling_mode), | ||||||||||||||||
| ::testing::ValuesIn(transa) | ||||||||||||||||
| ), | ||||||||||||||||
| [](const testing::TestParamInfo<SwizzleUnswizzleRoundtripTestSuite::ParamType>& info) { | ||||||||||||||||
| std::string name = "roundtrip_ntiles" + | ||||||||||||||||
| std::to_string(std::get<0>(info.param).first) + "X" + | ||||||||||||||||
| std::to_string(std::get<0>(info.param).second) + "smode" + | ||||||||||||||||
| std::to_string(std::get<1>(info.param).first) + "X"+ | ||||||||||||||||
| std::to_string(std::get<1>(info.param).second) + "trans" + | ||||||||||||||||
| std::to_string(std::get<2>(info.param)); | ||||||||||||||||
| return name; | ||||||||||||||||
| }); | ||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing space in skip message
The concatenated string produces
"...32is not implemented."(no space before"is"). Add a leading space.