From e3ccfea338ac2d5310fa43a945ed323bf0c6aa11 Mon Sep 17 00:00:00 2001
From: zhengxuegui
Date: Wed, 5 Jun 2024 17:54:00 +0800
Subject: [PATCH] [runtime] add BatchTranspose cuda kernel

---
 .../tensor_manipulate/kernels/transpose.cu   | 89 ++++++++++++++++++-
 .../tensor_manipulate/kernels/transpose.h    |  3 +
 .../default/tensor_manipulate/transpose.cc   | 33 ++++---
 .../default/tensor_manipulate/transpose.h    |  6 +-
 .../default/kernel/transpose_test.cc         | 48 ++++++++++
 5 files changed, 164 insertions(+), 15 deletions(-)

diff --git a/runtime/lib/backends/cuda/providers/default/tensor_manipulate/kernels/transpose.cu b/runtime/lib/backends/cuda/providers/default/tensor_manipulate/kernels/transpose.cu
index 9436bafc2..debd5cd02 100644
--- a/runtime/lib/backends/cuda/providers/default/tensor_manipulate/kernels/transpose.cu
+++ b/runtime/lib/backends/cuda/providers/default/tensor_manipulate/kernels/transpose.cu
@@ -16,11 +16,12 @@
 //===----------------------------------------------------------------------===//
 
 #include <cuda_fp16.h>
+#include <cstdint>
 
 namespace brt {
 namespace cuda {
 namespace kernel {
-
+constexpr int32_t kMaxGridDim = 65535;
 template <typename T>
 __global__ void transpose_naive_2d_kernel(const T *input, T *output, int m,
                                           int n) {
@@ -40,12 +41,98 @@ void transpose_naive_2d(const T *input, T *output, int m, int n, dim3 grid,
   transpose_naive_2d_kernel<T><<<grid, block, 0, stream>>>(input, output, m, n);
 }
 
+template <typename T, int32_t TileSizeX, int32_t TileSizeY, int32_t BlockSize>
+__global__ void batch_transpose_kernel(const int32_t total_tile_num,
+                                       const int32_t tile_num_in_dim0,
+                                       const int32_t tile_num_in_dim1,
+                                       const int32_t tile_per_sample,
+                                       const int32_t row, const int32_t col,
+                                       void *__restrict__ inp_ptr,
+                                       void *__restrict__ out_ptr) {
+  __shared__ T tile_in_shmem[TileSizeX][TileSizeY];
+  for (int32_t i = blockIdx.x, step_tile = gridDim.x; i < total_tile_num;
+       i += step_tile) {
+    const int32_t batch_idx = i / tile_per_sample;
+    const int32_t remainder = i - batch_idx * tile_per_sample;
+    const int32_t dim0_idx = remainder / tile_num_in_dim1;
+    const int32_t dim1_idx = remainder - dim0_idx * tile_num_in_dim1;
+
+    T *inp_tile_gmem = reinterpret_cast<T *>(inp_ptr);
+    T *out_tile_gmem = reinterpret_cast<T *>(out_ptr);
+    inp_tile_gmem += batch_idx * row * col + dim0_idx * TileSizeX * col +
+                     dim1_idx * TileSizeY;
+    out_tile_gmem += batch_idx * row * col + dim1_idx * TileSizeY * row +
+                     dim0_idx * TileSizeX;
+
+    int32_t range_0 = dim0_idx < tile_num_in_dim0 - 1
+                          ? TileSizeX
+                          : row - dim0_idx * TileSizeX;
+    int32_t range_1 = dim1_idx < tile_num_in_dim1 - 1
+                          ? TileSizeY
+                          : col - dim1_idx * TileSizeY;
+    constexpr int32_t row_num_per_iter = BlockSize / TileSizeY;
+    constexpr int32_t col_num_per_iter = BlockSize / TileSizeX;
+
+    int32_t tile_row_idx = threadIdx.x / TileSizeY;
+    int32_t tile_col_idx = threadIdx.x - tile_row_idx * TileSizeY;
+    for (int32_t j = tile_row_idx; j < range_0; j += row_num_per_iter) {
+      if (tile_col_idx < range_1) {
+        tile_in_shmem[j][tile_col_idx ^ j] =
+            inp_tile_gmem[j * col + tile_col_idx];
+      }
+    }
+    __syncthreads();
+    tile_row_idx = threadIdx.x / TileSizeX;
+    tile_col_idx = threadIdx.x - tile_row_idx * TileSizeX;
+    for (int32_t j = tile_row_idx; j < range_1; j += col_num_per_iter) {
+      if (tile_col_idx < range_0) {
+        out_tile_gmem[j * row + tile_col_idx] =
+            tile_in_shmem[tile_col_idx][j ^ tile_col_idx];
+      }
+    }
+    __syncthreads();
+  }
+}
+
+template <typename T>
+void batch_transpose(int32_t batch, int32_t row, int32_t col, const T *inp_ptr,
+                     T *out_ptr, cudaStream_t stream) {
+  constexpr int32_t kTileSize = 32;
+
+  const int32_t tile_num_in_dim0 = (row - 1) / kTileSize + 1;
+  const int32_t tile_num_in_dim1 = (col - 1) / kTileSize + 1;
+  const int32_t tile_per_sample = tile_num_in_dim0 * tile_num_in_dim1;
+  const int32_t total_tile_num = batch * tile_per_sample;
+  dim3 grid(total_tile_num >= kMaxGridDim ? kMaxGridDim : total_tile_num);
+  if (row < 8 || col < 8) {
+    constexpr int32_t kBlockSize = 64;
+    dim3 block(kBlockSize);
+    batch_transpose_kernel<T, kTileSize, kTileSize, kBlockSize>
+        <<<grid, block, 0, stream>>>(
+            total_tile_num, tile_num_in_dim0, tile_num_in_dim1, tile_per_sample,
+            row, col, reinterpret_cast<void *>(const_cast<T *>(inp_ptr)),
+            reinterpret_cast<void *>(out_ptr));
+  } else {
+    constexpr int32_t kBlockSize = 256;
+    dim3 block(kBlockSize);
+    batch_transpose_kernel<T, kTileSize, kTileSize, kBlockSize>
+        <<<grid, block, 0, stream>>>(
+            total_tile_num, tile_num_in_dim0, tile_num_in_dim1, tile_per_sample,
+            row, col, reinterpret_cast<void *>(const_cast<T *>(inp_ptr)),
+            reinterpret_cast<void *>(out_ptr));
+  }
+}
+
 // instantiate
 template void transpose_naive_2d<float>(const float *, float *, int, int,
                                         dim3, dim3, cudaStream_t);
 template void transpose_naive_2d<__half>(const __half *, __half *, int, int,
                                          dim3, dim3, cudaStream_t);
+template void batch_transpose<float>(int32_t, int32_t, int32_t, const float *,
+                                     float *, cudaStream_t);
+template void batch_transpose<__half>(int32_t, int32_t, int32_t, const __half *,
+                                      __half *, cudaStream_t);
 
 } // namespace kernel
 } // namespace cuda
 } // namespace brt
diff --git a/runtime/lib/backends/cuda/providers/default/tensor_manipulate/kernels/transpose.h b/runtime/lib/backends/cuda/providers/default/tensor_manipulate/kernels/transpose.h
index bc3fb9f54..3c9281688 100644
--- a/runtime/lib/backends/cuda/providers/default/tensor_manipulate/kernels/transpose.h
+++ b/runtime/lib/backends/cuda/providers/default/tensor_manipulate/kernels/transpose.h
@@ -26,6 +26,9 @@ template <typename T>
 void transpose_naive_2d(const T *input, T *output, int m, int n, dim3 grid,
                         dim3 block, cudaStream_t stream);
 
+template <typename T>
+void batch_transpose(int32_t batch, int32_t row, int32_t col, const T *inp_ptr,
+                     T *out_ptr, cudaStream_t stream);
 } // namespace kernel
 } // namespace cuda
 } // namespace brt
diff --git a/runtime/lib/backends/cuda/providers/default/tensor_manipulate/transpose.cc b/runtime/lib/backends/cuda/providers/default/tensor_manipulate/transpose.cc
index d9929e569..f5b12e329 100644
--- a/runtime/lib/backends/cuda/providers/default/tensor_manipulate/transpose.cc
+++ b/runtime/lib/backends/cuda/providers/default/tensor_manipulate/transpose.cc
@@ -40,11 +40,12 @@ using namespace brt::ir;
 
 namespace brt {
 namespace cuda {
 
-template <typename T> Transpose2D<T>::Transpose2D(const OpAccessor &accessor) {
+template <typename T>
+BatchTranspose<T>::BatchTranspose(const OpAccessor &accessor) {
   auto shape_input = accessor.GetArgShape(0);
   auto shape_output = accessor.GetArgShape(1);
-  BRT_ENFORCE(shape_input.size() == 2);
+  BRT_ENFORCE((shape_input.size() == 2 || shape_input.size() == 3));
   BRT_ENFORCE(shape_output ==
               transpose::DeduceOutputShape(
                   shape_input, accessor.GetAttrAsIntArray("permutation")));
@@ -52,18 +53,22 @@ template <typename T> Transpose2D<T>::Transpose2D(const OpAccessor &accessor) {
 }
 
 template <typename T>
-void Transpose2D<T>::Execute(const T *input, T *output,
-                             cudnnHandle_t /*handle*/, cudaStream_t stream) {
+void BatchTranspose<T>::Execute(const T *input, T *output,
+                                cudnnHandle_t /*handle*/, cudaStream_t stream) {
   auto p = MakeCUDAGridAndBlock(input_shape[1], input_shape[0]);
-  kernel::transpose_naive_2d(input, output, static_cast<int>(input_shape[0]),
-                             static_cast<int>(input_shape[1]), p.first,
-                             p.second, stream);
+  int32_t batch = 1, m, n;
+  if (input_shape.size() == 2) {
+    m = input_shape[0], n = input_shape[1];
+  } else if (input_shape.size() == 3) {
+    batch = input_shape[0], m = input_shape[1], n = input_shape[2];
+  }
+  kernel::batch_transpose(batch, m, n, input, output, stream);
   BRT_CUDA_CHECK(cudaGetLastError());
 }
 
 // instantiate
-template class Transpose2D<float>;
-template class Transpose2D<__half>;
+template class BatchTranspose<float>;
+template class BatchTranspose<__half>;
 
 template <typename T> Transpose4D<T>::Transpose4D(const OpAccessor &accessor) {
   auto shape_input = accessor.GetArgShape(0);
@@ -134,8 +139,14 @@ template class Transpose4D<__half>;
 template <typename T>
 TransposeImpl<T>::TransposeImpl(const OpAccessor &accessor) {
   auto shape_input = accessor.GetArgShape(0);
-  if (shape_input.size() == 2) {
-    this->impl = new Transpose2D<T>(accessor);
+  if (shape_input.size() == 2 || shape_input.size() == 3) {
+    auto permutation = accessor.GetAttrAsIntArray("permutation");
+    if (permutation[permutation.size() - 2] == permutation.size() - 1 &&
+        permutation[permutation.size() - 1] == permutation.size() - 2) {
+      this->impl = new BatchTranspose<T>(accessor);
+    } else {
+      BRT_THROW("unsupported transpose");
+    }
   } else if (shape_input.size() == 4) {
     this->impl = new Transpose4D<T>(accessor);
   } else {
diff --git a/runtime/lib/backends/cuda/providers/default/tensor_manipulate/transpose.h b/runtime/lib/backends/cuda/providers/default/tensor_manipulate/transpose.h
index ab36caccb..8d9c11883 100644
--- a/runtime/lib/backends/cuda/providers/default/tensor_manipulate/transpose.h
+++ b/runtime/lib/backends/cuda/providers/default/tensor_manipulate/transpose.h
@@ -34,11 +34,11 @@ template <typename T> class TransposeBase {
 };
 
 /**
- * Transpose2D
+ * BatchTranspose
  */
-template <typename T> class Transpose2D : public TransposeBase<T> {
+template <typename T> class BatchTranspose : public TransposeBase<T> {
 public:
-  explicit Transpose2D(const OpAccessor &accessor);
+  explicit BatchTranspose(const OpAccessor &accessor);
 
   virtual void Execute(const T *input, T *output, cudnnHandle_t handle,
                        cudaStream_t stream) override;
diff --git a/runtime/test/backends/cuda/providers/default/kernel/transpose_test.cc b/runtime/test/backends/cuda/providers/default/kernel/transpose_test.cc
index 5ada9095e..9ab0f8f1a 100644
--- a/runtime/test/backends/cuda/providers/default/kernel/transpose_test.cc
+++ b/runtime/test/backends/cuda/providers/default/kernel/transpose_test.cc
@@ -63,6 +63,38 @@ static void CheckTranspose2D(T *input, T *output,
   free(h_output);
 }
 
+template <typename T>
+static void CheckTranspose3D(T *input, T *output,
+                             const std::vector<int64_t> &input_shape) {
+  T *h_input =
+      (T *)malloc(input_shape[0] * input_shape[1] * input_shape[2] * sizeof(T));
+  T *h_output =
+      (T *)malloc(input_shape[0] * input_shape[1] * input_shape[2] * sizeof(T));
+  cudaMemcpy(h_input, input,
+             input_shape[0] * input_shape[1] * input_shape[2] * sizeof(T),
+             cudaMemcpyDeviceToHost);
+  cudaMemcpy(h_output, output,
+             input_shape[0] * input_shape[1] * input_shape[2] * sizeof(T),
+             cudaMemcpyDeviceToHost);
+  cudaDeviceSynchronize();
+
+  int B = input_shape[0];
+  int m = input_shape[1];
+  int n = input_shape[2];
+  for (int64_t t = 0; t < B; ++t) {
+    for (int64_t i = 0; i < m; ++i) {
+      for (int64_t j = 0; j < n; ++j) {
+        int in_idx = t * m * n + i * n + j;
+        int out_idx = t * m * n + j * m + i;
+        EXPECT_EQ(h_output[out_idx], h_input[in_idx]);
+      }
+    }
+  }
+
+  free(h_input);
+  free(h_output);
+}
+
 template <typename T>
 static void CheckTranspose4D(T *input, T *output,
                              const std::vector<int64_t> &input_shape,
@@ -142,6 +174,8 @@ static void TestTranspose(std::vector<int64_t> shape_input,
 
   if (shape_input.size() == 2) {
     CheckTranspose2D(d_input, d_output, shape_input);
+  } else if (shape_input.size() == 3) {
+    CheckTranspose3D(d_input, d_output, shape_input);
   } else if (shape_input.size() == 4) {
     CheckTranspose4D(d_input, d_output, shape_input, perm);
   } else {
@@ -150,8 +184,16 @@ static void TestTranspose(std::vector<int64_t> shape_input,
 }
 
 TEST(CUDAOpKerenlTest, TransposeOp) {
+  // 2D transpose
   TestTranspose({32, 64}, {64, 32}, {1, 0});
+  TestTranspose({2, 1}, {1, 2}, {1, 0});
+  TestTranspose({1007, 13}, {13, 1007}, {1, 0});
+  TestTranspose({2007, 4339}, {4339, 2007}, {1, 0});
   TestTranspose({1000, 512}, {512, 1000}, {1, 0});
+  // 3D Batch transpose
+  TestTranspose({13, 789, 1234}, {13, 1234, 789}, {0, 2, 1});
+  TestTranspose({65536, 32, 50}, {65536, 50, 32}, {0, 2, 1});
+  TestTranspose({65536, 2, 50}, {65536, 50, 2}, {0, 2, 1});
   // NCHW 2 NHWC
   TestTranspose({10, 20, 30, 40}, {10, 30, 40, 20}, {0, 2, 3, 1});
   // NHWC 2 NCHW
@@ -159,8 +201,14 @@ TEST(CUDAOpKerenlTest, TransposeOp) {
 }
 
 TEST(CUDAOpKerenlTest, TransposeOpFp16) {
+  // 2D transpose
   TestTranspose<__half>({32, 64}, {64, 32}, {1, 0});
+  TestTranspose<__half>({2, 1}, {1, 2}, {1, 0});
+  TestTranspose<__half>({1007, 13}, {13, 1007}, {1, 0});
+  TestTranspose<__half>({2007, 4339}, {4339, 2007}, {1, 0});
   TestTranspose<__half>({1000, 512}, {512, 1000}, {1, 0});
+  // 3D Batch transpose
+  TestTranspose<__half>({13, 789, 1234}, {13, 1234, 789}, {0, 2, 1});
   // NCHW 2 NHWC
   TestTranspose<__half>({10, 20, 30, 40}, {10, 30, 40, 20}, {0, 2, 3, 1});
   // NHWC 2 NCHW
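
Usage sketch (not part of the patch): the new entry point kernel::batch_transpose<T>(batch, row, col, inp, out, stream) transposes the last two dimensions of a row-major batch x row x col tensor, and TransposeImpl now dispatches rank-2/3 inputs whose permutation swaps the last two axes to it. The standalone harness below is one way to exercise the kernel directly; it assumes the translation unit is built inside the byteir runtime tree so that the kernels/transpose.h header above is on the include path, and main(), the buffer names, and the shapes are illustrative only.

// Illustrative harness only -- not part of this patch. Assumes the byteir
// runtime include path provides the kernels/transpose.h declared above.
#include <cstdint>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include "kernels/transpose.h"

int main() {
  const int32_t batch = 4, row = 37, col = 53; // deliberately not multiples of 32
  const size_t count = size_t(batch) * row * col;
  std::vector<float> h_in(count), h_out(count);
  for (size_t i = 0; i < count; ++i)
    h_in[i] = static_cast<float>(i);

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, count * sizeof(float));
  cudaMalloc(&d_out, count * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), count * sizeof(float), cudaMemcpyHostToDevice);

  // Transpose the last two dims of every sample on the default stream.
  brt::cuda::kernel::batch_transpose<float>(batch, row, col, d_in, d_out,
                                            /*stream=*/0);

  cudaMemcpy(h_out.data(), d_out, count * sizeof(float), cudaMemcpyDeviceToHost);
  cudaDeviceSynchronize();

  // Reference check: out[b][j][i] must equal in[b][i][j].
  bool ok = true;
  for (int32_t b = 0; b < batch && ok; ++b)
    for (int32_t i = 0; i < row && ok; ++i)
      for (int32_t j = 0; j < col && ok; ++j)
        ok = h_out[size_t(b) * row * col + size_t(j) * row + i] ==
             h_in[size_t(b) * row * col + size_t(i) * col + j];
  std::printf("batch_transpose %s\n", ok ? "OK" : "MISMATCH");

  cudaFree(d_in);
  cudaFree(d_out);
  return ok ? 0 : 1;
}

Since the patch also instantiates batch_transpose for __half, the same harness works unchanged with T = __half (filling the input from float values via __float2half).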