From 52ece47d351cdc5ccf321188998b0da53f0f2b4e Mon Sep 17 00:00:00 2001 From: skishore Date: Thu, 6 Nov 2025 13:41:17 +0000 Subject: [PATCH] hipify rnnt loss code, include corrected header files, add hipify pytorch as a submodule --- .gitmodules | 3 + CMakeLists.txt | 3 +- src/libtorchaudio/CMakeLists.txt | 50 +- src/libtorchaudio/rnnt/gpu/compute.cu | 4 + src/libtorchaudio/rnnt/gpu/compute.hip | 177 +++++++ .../rnnt/gpu/gpu_kernel_utils.cuh | 4 + .../rnnt/gpu/gpu_kernel_utils_hip.cuh | 112 +++++ src/libtorchaudio/rnnt/gpu/gpu_kernels.cuh | 6 + .../rnnt/gpu/gpu_kernels_hip.cuh | 446 ++++++++++++++++++ src/libtorchaudio/rnnt/gpu/gpu_transducer.h | 5 + .../rnnt/gpu/gpu_transducer_hip.h | 402 ++++++++++++++++ src/libtorchaudio/rnnt/gpu/kernel_utils.h | 4 + src/libtorchaudio/rnnt/gpu/kernels.h | 5 + src/libtorchaudio/rnnt/gpu/math_hip.cuh | 49 ++ src/libtorchaudio/rnnt/macros.h | 8 + src/libtorchaudio/rnnt/options.h | 9 +- src/libtorchaudio/rnnt/workspace.h | 16 +- third_party/hipify_torch | 1 + 18 files changed, 1296 insertions(+), 8 deletions(-) create mode 100644 src/libtorchaudio/rnnt/gpu/compute.hip create mode 100644 src/libtorchaudio/rnnt/gpu/gpu_kernel_utils_hip.cuh create mode 100644 src/libtorchaudio/rnnt/gpu/gpu_kernels_hip.cuh create mode 100644 src/libtorchaudio/rnnt/gpu/gpu_transducer_hip.h create mode 100644 src/libtorchaudio/rnnt/gpu/math_hip.cuh create mode 160000 third_party/hipify_torch diff --git a/.gitmodules b/.gitmodules index e69de29bb2..25d307cea8 100644 --- a/.gitmodules +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/hipify_torch"] + path = third_party/hipify_torch + url = https://github.com/ROCmSoftwarePlatform/hipify_torch diff --git a/CMakeLists.txt b/CMakeLists.txt index ca35034da1..9c4f50c5a4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -68,10 +68,11 @@ if(USE_CUDA AND USE_ROCM) endif() if(USE_ROCM) + enable_language(HIP) # Find the HIP package, set the HIP paths, load the HIP CMake. include(cmake/LoadHIP.cmake) if(NOT PYTORCH_FOUND_HIP) - set(USE_ROCM OFF) + #set(USE_ROCM OFF) endif() endif() diff --git a/src/libtorchaudio/CMakeLists.txt b/src/libtorchaudio/CMakeLists.txt index 063aa93e34..751a02e106 100644 --- a/src/libtorchaudio/CMakeLists.txt +++ b/src/libtorchaudio/CMakeLists.txt @@ -1,6 +1,21 @@ ################################################################################ # libtorchaudio ################################################################################ +if(USE_ROCM) + list (APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm) + FIND_PACKAGE(HIP REQUIRED) + MESSAGE(STATUS "hip found ${ROCM_FOUND}") + + list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/third_party/hipify_torch/cmake") + include(Hipify) + + set(CMAKE_CXX_COMPILER ${HIP_HIPCC_EXECUTABLE}) + set(CMAKE_CXX_LINKER ${HIP_HIPCC_EXECUTABLE}) + set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) + list( APPEND CMAKE_INSTALL_RPATH "/opt/rocm/llvm/lib" ) + +endif() + set( sources lfilter.cpp @@ -32,6 +47,19 @@ if(BUILD_RNNT) rnnt/gpu/compute.cu ) endif() + if (USE_ROCM) + hipify(CUDA_SOURCE_DIR ${PROJECT_SOURCE_DIR}/src/libtorchaudio/rnnt/gpu) + if ( NOT HIP_ADD_LIBRARY_FOUND ) + list(APPEND CMAKE_MODULE_PATH /opt/rocm/hip/cmake) + find_package(HIP REQUIRED) + endif() + + list( + APPEND + sources + rnnt/gpu/compute.hip + ) + endif() endif() if(BUILD_ALIGN) @@ -64,12 +92,28 @@ if(USE_CUDA) ) endif() -if(OpenMP_CXX_FOUND) +if(USE_ROCM) list( APPEND - additional_libs - OpenMP::OpenMP_CXX + additional_libs + hip::host + hip::device ) + list( + APPEND + compile_definitions + USE_ROCM + ) +endif() + +if(USE_CUDA) + if(OpenMP_CXX_FOUND) + list( + APPEND + additional_libs + OpenMP::OpenMP_CXX + ) + endif() endif() #------------------------------------------------------------------------------# diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu index 336e3b8abd..2d964cd9df 100644 --- a/src/libtorchaudio/rnnt/gpu/compute.cu +++ b/src/libtorchaudio/rnnt/gpu/compute.cu @@ -1,4 +1,8 @@ +#ifdef __HIP_PLATFORM_AMD__ +#include +#else #include +#endif #include #include diff --git a/src/libtorchaudio/rnnt/gpu/compute.hip b/src/libtorchaudio/rnnt/gpu/compute.hip new file mode 100644 index 0000000000..d851f0f467 --- /dev/null +++ b/src/libtorchaudio/rnnt/gpu/compute.hip @@ -0,0 +1,177 @@ +// !!! This is a file automatically generated by hipify!!! +#ifdef __HIP_PLATFORM_AMD__ +#include +#else +#include +#endif + +#include +#include +#include +#include + +namespace torchaudio { +namespace rnnt { +namespace gpu { + +using torch::stable::Tensor; +using torch::headeronly::ScalarType; + +// Entry point into RNNT Loss +std::tuple compute( + const Tensor& logits, + const Tensor& targets, + const Tensor& logit_lengths, + const Tensor& target_lengths, + int64_t blank, + double clamp, + bool fused_log_softmax = true) { + STD_TORCH_CHECK(logits.is_cuda(), "logits must be on CUDA"); + + STD_TORCH_CHECK( + targets.is_cuda() && targets.get_device_index() == logits.get_device_index(), + "logits and targets must be on the same device"); + STD_TORCH_CHECK( + logit_lengths.is_cuda() && logit_lengths.get_device_index() == logits.get_device_index(), + "logits and logit_lengths must be on the same device"); + STD_TORCH_CHECK( + target_lengths.is_cuda() && target_lengths.get_device_index() == logits.get_device_index(), + "logits and target_lengths must be on the same device"); + + STD_TORCH_CHECK( + logits.scalar_type() == ScalarType::Float || logits.scalar_type() == ScalarType::Half, + "logits must be float32 or float16 (half) type"); + + STD_TORCH_CHECK(targets.scalar_type() == ScalarType::Int, "targets must be int32 type"); + + STD_TORCH_CHECK( + logit_lengths.scalar_type() == ScalarType::Int, + "logit_lengths must be int32 type"); + STD_TORCH_CHECK( + target_lengths.scalar_type() == ScalarType::Int, + "target_lengths must be int32 type"); + + STD_TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous"); + STD_TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); + STD_TORCH_CHECK( + logit_lengths.is_contiguous(), "logit_lengths must be contiguous"); + STD_TORCH_CHECK( + target_lengths.is_contiguous(), "target_lengths must be contiguous"); + + STD_TORCH_CHECK( + logits.dim() == 4, "logits must be 4-D (batch, time, target, class)"); + STD_TORCH_CHECK( + targets.dim() == 2, "targets must be 2-D (batch, max target length)"); + STD_TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D"); + STD_TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D"); + + STD_TORCH_CHECK( + logit_lengths.size(0) == logits.size(0), + "batch dimension mismatch between logits and logit_lengths"); + STD_TORCH_CHECK( + target_lengths.size(0) == logits.size(0), + "batch dimension mismatch between logits and target_lengths"); + STD_TORCH_CHECK( + targets.size(0) == logits.size(0), + "batch dimension mismatch between logits and targets"); + + STD_TORCH_CHECK( + blank >= 0 && blank < logits.size(-1), + "blank must be within [0, logits.shape[-1])"); + + auto max_ivalue = [](const Tensor& t) { + int32_t value; + C10_HIP_CHECK(hipMemcpy(&value, torch::stable::amax(t, {}).data_ptr(), sizeof(int32_t), hipMemcpyDeviceToHost)); + return value; + }; + + STD_TORCH_CHECK( + logits.size(1) == max_ivalue(logit_lengths), + "input length mismatch"); + STD_TORCH_CHECK( + logits.size(2) == max_ivalue(target_lengths) + 1, + "output length mismatch"); + STD_TORCH_CHECK( + targets.size(1) + 1 == logits.size(2), + "target length mismatch"); + + Options options; + options.batchSize_ = logit_lengths.size(0); + options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); + options.maxSrcLen_ = logits.size(1); + options.maxTgtLen_ = logits.size(2); + options.numTargets_ = logits.size(3); + options.blank_ = blank; + options.clamp_ = clamp; + options.fusedLogSmax_ = fused_log_softmax; + options.stream_ = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + hipSetDevice(logits.get_device()); + options.device_ = GPU; + + Tensor costs = torch::stable::new_empty(logits, {options.batchSize_ * options.nHypos_}); + Tensor gradients = torch::stable::empty_like(logits); + torch::stable::fill_(gradients, 0.0); + + Tensor int_workspace = torch::stable::new_empty(logits, {IntWorkspace::ComputeSizeFromOptions(options)}, ScalarType::Int); + Tensor float_workspace = torch::stable::new_empty(logits, {DtypeWorkspace::ComputeSizeFromOptions(options)}, ScalarType::Float); + + Workspace workspace( + /*options=*/options, + /*dtype_data=*/reinterpret_cast(float_workspace.data_ptr()), + /*dtype_size=*/float_workspace.numel(), + /*int_data=*/reinterpret_cast(int_workspace.data_ptr()), + /*int_size=*/int_workspace.numel()); + + switch (logits.scalar_type()) { + case ScalarType::Float: { + Compute( + /*workspace=*/workspace, + /*logits=*/reinterpret_cast(logits.data_ptr()), + /*targets=*/reinterpret_cast(targets.data_ptr()), + /*srcLengths=*/reinterpret_cast(logit_lengths.data_ptr()), + /*tgtLengths=*/reinterpret_cast(target_lengths.data_ptr()), + /*costs=*/reinterpret_cast(costs.data_ptr()), + /*gradients=*/reinterpret_cast(gradients.data_ptr())); + break; + } + case ScalarType::Half: { + Compute( + /*workspace=*/workspace, + /*logits=*/reinterpret_cast(logits.data_ptr()), + /*targets=*/reinterpret_cast(targets.data_ptr()), + /*srcLengths=*/reinterpret_cast(logit_lengths.data_ptr()), + /*tgtLengths=*/reinterpret_cast(target_lengths.data_ptr()), + /*costs=*/reinterpret_cast(costs.data_ptr()), + /*gradients=*/reinterpret_cast(gradients.data_ptr())); + break; + } + default: { + STD_TORCH_CHECK(false, "unreachable"); + } + }; + + return std::make_tuple(costs, gradients); +} + +void boxed_rnnt_loss(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + STD_TORCH_CHECK(num_args == 7, "num_args must be 7"); + STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2"); + std::tuple res = compute( + /*logits*/torch::stable::detail::to(stack[0]), + /*targets*/torch::stable::detail::to(stack[1]), + /*logit_lengths*/torch::stable::detail::to(stack[2]), + /*target_lengths*/torch::stable::detail::to(stack[3]), + /*blank*/float(torch::stable::detail::to(stack[4])), + /*clamp*/torch::stable::detail::to(stack[5]), + /*fused_log_softmax*/torch::stable::detail::to(stack[6])); + stack[0] = torch::stable::detail::from(std::get<0>(res)); + stack[1] = torch::stable::detail::from(std::get<1>(res)); +} + +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { + m.impl("rnnt_loss_forward", &boxed_rnnt_loss); +} + +} // namespace gpu +} // namespace rnnt +} // namespace torchaudio diff --git a/src/libtorchaudio/rnnt/gpu/gpu_kernel_utils.cuh b/src/libtorchaudio/rnnt/gpu/gpu_kernel_utils.cuh index f4ad3add2b..4d97e03881 100644 --- a/src/libtorchaudio/rnnt/gpu/gpu_kernel_utils.cuh +++ b/src/libtorchaudio/rnnt/gpu/gpu_kernel_utils.cuh @@ -2,7 +2,11 @@ #ifdef USE_CUDA +#ifdef __HIP_PLATFORM_AMD__ +#include +#else #include +#endif namespace torchaudio { namespace rnnt { diff --git a/src/libtorchaudio/rnnt/gpu/gpu_kernel_utils_hip.cuh b/src/libtorchaudio/rnnt/gpu/gpu_kernel_utils_hip.cuh new file mode 100644 index 0000000000..19f23743c8 --- /dev/null +++ b/src/libtorchaudio/rnnt/gpu/gpu_kernel_utils_hip.cuh @@ -0,0 +1,112 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +#pragma once + +#ifdef USE_ROCM + +#ifdef __HIP_PLATFORM_AMD__ +#include +#else +#include +#endif + +namespace torchaudio { +namespace rnnt { + +template +__global__ void ReduceMax2D( + int dim, + const DTYPE* inputs, // [N, dim] + CAST_DTYPE* outputs) { + __shared__ CAST_DTYPE shared[NUM_THREADS]; + + // each thread reduces one matrix row + int offset = blockIdx.x * dim; // [n, 0] + CAST_DTYPE val = inputs[offset]; // default = inputs(n, 0) + for (int d = threadIdx.x; d < dim; d += NUM_THREADS) { + CAST_DTYPE next = inputs[offset + d]; + if (next > val) { + val = next; + } + } + + shared[threadIdx.x] = val; + __syncthreads(); + + for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) { + if (threadIdx.x < stride && threadIdx.x + stride < dim) { + if (shared[threadIdx.x + stride] > shared[threadIdx.x]) { + shared[threadIdx.x] = shared[threadIdx.x + stride]; + val = shared[threadIdx.x]; + } + } + __syncthreads(); + } + + CAST_DTYPE shf; + for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) { +#ifndef USE_ROCM + shf = __shfl_down_sync(0xFFFFFFFF, val, stride); +#else + shf = __shfl_down(val, stride); +#endif + if (threadIdx.x < stride && threadIdx.x + stride < dim) { + if (shf > val) { + val = shf; + } + } + } + + if (threadIdx.x == 0) { + outputs[blockIdx.x] = val; + } +} + +template +__global__ void ReduceLogSumExpGivenMax2D( + int dim, + const DTYPE* inputs, // [N, dim] + CAST_DTYPE* outputs) { // in: max -> out: logsum + + __shared__ CAST_DTYPE shared[NUM_THREADS]; + + CAST_DTYPE max = outputs[blockIdx.x]; + CAST_DTYPE val = 0; + + int offset = blockIdx.x * dim; + for (int d = threadIdx.x; d < dim; d += NUM_THREADS) { + val = val + ::exp(CAST_DTYPE(inputs[offset + d]) - max); + } + + shared[threadIdx.x] = val; + __syncthreads(); + + for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) { + if (threadIdx.x < stride && threadIdx.x + stride < dim) { + val = shared[threadIdx.x] + shared[threadIdx.x + stride]; + shared[threadIdx.x] = val; + } + __syncthreads(); + } + + CAST_DTYPE shf; + for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) { +#ifndef USE_ROCM + shf = __shfl_down_sync(0xFFFFFFFF, val, stride); +#else + shf = __shfl_down(val, stride); +#endif + if (threadIdx.x < stride && threadIdx.x + stride < dim) { + val = val + shf; + } + } + + if (threadIdx.x == 0) { + outputs[blockIdx.x] = max + ::log(val); + } +} + +} // namespace rnnt +} // namespace torchaudio + +#endif // USE_ROCM diff --git a/src/libtorchaudio/rnnt/gpu/gpu_kernels.cuh b/src/libtorchaudio/rnnt/gpu/gpu_kernels.cuh index 136e6844f2..fe3adc4115 100644 --- a/src/libtorchaudio/rnnt/gpu/gpu_kernels.cuh +++ b/src/libtorchaudio/rnnt/gpu/gpu_kernels.cuh @@ -4,9 +4,15 @@ #include +#ifdef __HIP_PLATFORM_AMD__ +#include +#include +#include +#else #include #include #include +#endif namespace torchaudio { namespace rnnt { diff --git a/src/libtorchaudio/rnnt/gpu/gpu_kernels_hip.cuh b/src/libtorchaudio/rnnt/gpu/gpu_kernels_hip.cuh new file mode 100644 index 0000000000..d4c2e6b776 --- /dev/null +++ b/src/libtorchaudio/rnnt/gpu/gpu_kernels_hip.cuh @@ -0,0 +1,446 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +#pragma once + +#ifdef USE_ROCM + +#include + +#ifdef __HIP_PLATFORM_AMD__ +#include +#include +#include +#else +#include +#include +#include +#endif + +namespace torchaudio { +namespace rnnt { + +template +__global__ void ComputeLogProbs( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + const CAST_DTYPE* denominators, + CAST_DTYPE* logProbs, + int H = 1, + bool fusedLogSmax = true) { + const int& maxT = maxSrcLen; + const int& maxU = maxTgtLen; + const int& D = numTargets; + + const int bTgt = blockIdx.z; // 0 <= b < B + const int bSrc = bTgt / H; + const int T = srcLengths[bSrc]; + const int U = tgtLengths[bTgt] + 1; + + const int t = blockIdx.x * blockDim.x + threadIdx.x; + const int u = blockIdx.y; + + if (t >= T || u >= U) { // out of boundary. + return; + } + + Indexer3D indexer(maxT, maxU); + + int idx = indexer(bTgt, t, u); + + // skip: log_prob(b, t, u).skip() = logits(b, t, u, blank) - denom(b, t, u). + logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] = + CAST_DTYPE(logits[idx * D + blank]) - denominators[idx]; + + if (!fusedLogSmax) { + logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] = + CAST_DTYPE(logits[idx * D + blank]); + } + + if (u < U - 1) { + // emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t, + // u). + int target = targets[Indexer2D(maxU - 1)(bTgt, u)]; + logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] = + CAST_DTYPE(logits[idx * D + target]) - denominators[idx]; + + if (!fusedLogSmax) { + logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] = + CAST_DTYPE(logits[idx * D + target]); + } + } +} + +template +__device__ void ComputeAlphas( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + const CAST_DTYPE* logProbs, + const int* srcLengths, + const int* tgtLengths, + int* alpha_counters, + volatile CAST_DTYPE* alphas, + int H = 1) { + const int& maxT = maxSrcLen; + const int& maxU = maxTgtLen; + + const int bTgt = blockIdx.z; // 0 <= b < B + const int bSrc = bTgt / H; + const int T = srcLengths[bSrc]; + const int U = tgtLengths[bTgt] + 1; + + const int t = blockIdx.x * blockDim.x + threadIdx.x + 1; + const int u = blockIdx.y + 1; + + if (t >= T || u >= U) { // out of boundary. + return; + } + + int* counter = alpha_counters + Indexer2D(maxU)(bTgt, blockIdx.y); + + Indexer3D idxr(maxT, maxU); + + if (t == 1 && u == 1) { + alphas[idxr(bTgt, 0, 0)] = 0; + } + + if (blockIdx.x > 0) { // wait for previous warp (in t-axis) is ready. + while (atomicAdd(counter, 0) < blockIdx.x) { + } + } + + if (blockIdx.y > 0) { // wait for previous warp (in u-axis) is ready. + while (atomicAdd(counter - 1, 0) <= blockIdx.x) { + } + } + + if (t == 1 && u < U) { + // alpha(0, u) = alpha(0, u - 1) + logProbs(0, u - 1).emit(). + alphas[idxr(bTgt, 0, u)] = alphas[idxr(bTgt, 0, u - 1)] + + logProbs[(idxr(bTgt, 0, u - 1) << 1) + LOG_PROBS_EMIT_IDX]; + } + + if (blockIdx.y == 0 && t < T) { + CAST_DTYPE skip_prob = + logProbs[(idxr(bTgt, t - 1, 0) << 1) + LOG_PROBS_SKIP_IDX]; + CAST_DTYPE val; + +#pragma unroll + for (int i = 1; i < warpSize; i <<= 1) { +#ifndef USE_ROCM + val = __shfl_up_sync(0xffffffff, skip_prob, i); +#else + val = __shfl_up(skip_prob, i); +#endif + if (i <= threadIdx.x) { + skip_prob = skip_prob + val; + } + } + + val = alphas[idxr(bTgt, blockIdx.x * blockDim.x, 0)]; + alphas[idxr(bTgt, t, 0)] = skip_prob + val; + } + + if (t < T && u < U) { + CAST_DTYPE skip_prob = + logProbs[(idxr(bTgt, t - 1, u) << 1) + LOG_PROBS_SKIP_IDX]; + CAST_DTYPE emit_prob = + logProbs[(idxr(bTgt, t, u - 1) << 1) + LOG_PROBS_EMIT_IDX]; + + CAST_DTYPE skip = + alphas[idxr(bTgt, blockIdx.x * blockDim.x, u)] + skip_prob; + CAST_DTYPE emit = alphas[idxr(bTgt, t, u - 1)] + emit_prob; + + CAST_DTYPE val = math::lse(skip, emit); + CAST_DTYPE out = val; + + for (int i = 1; i < warpSize; ++i) { +#ifndef USE_ROCM + val = __shfl_up_sync(0xffffffff, val, 1); +#else + val = __shfl_up(val, 1); +#endif + if (i == threadIdx.x) { + val = math::lse(val + skip_prob, emit); + out = val; + } + } + + alphas[idxr(bTgt, t, u)] = out; + } + + if (threadIdx.x == 0) { + __threadfence(); + atomicAdd(counter, 1); + } +} + +template +__device__ void ComputeBetasCosts( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + const CAST_DTYPE* logProbs, + const int* srcLengths, + const int* tgtLengths, + int* betaCounters, + volatile CAST_DTYPE* betas, + DTYPE* costs, + int H = 1) { + const int& maxT = maxSrcLen; + const int& maxU = maxTgtLen; + + const int bTgt = blockIdx.z; // 0 <= b < B + const int bSrc = bTgt / H; + const int T = srcLengths[bSrc]; + const int U = tgtLengths[bTgt] + 1; + + const int t = T - 2 - blockIdx.x * blockDim.x - threadIdx.x; + const int u = U - 2 - blockIdx.y; + + if (t < 0 || u < 0) { // out of boundary. + return; + } + + int* counter = betaCounters + Indexer2D(maxU)(bTgt, blockIdx.y); + + Indexer3D idxr(maxT, maxU); + + if (t == T - 2 && u == U - 2) { + betas[idxr(bTgt, T - 1, U - 1)] = + logProbs[(idxr(bTgt, T - 1, U - 1) << 1) + LOG_PROBS_SKIP_IDX]; + } + + if (blockIdx.x > 0) { // wait for previous warp (in t-axis) is ready. + while (atomicAdd(counter, 0) < blockIdx.x) { + } + } + + if (blockIdx.y > 0) { // wait for previous warp (in u-axis) is ready. + while (atomicAdd(counter - 1, 0) <= blockIdx.x) { + } + } + + if (t == T - 2 && u >= 0) { + betas[idxr(bTgt, T - 1, u)] = betas[idxr(bTgt, T - 1, u + 1)] + + logProbs[(idxr(bTgt, T - 1, u) << 1) + LOG_PROBS_EMIT_IDX]; + } + + if (blockIdx.y == 0 && t >= 0) { + CAST_DTYPE skip_prob = + logProbs[(idxr(bTgt, t, U - 1) << 1) + LOG_PROBS_SKIP_IDX]; + CAST_DTYPE val; + +#pragma unroll + for (int i = 1; i < warpSize; i <<= 1) { +#ifndef USE_ROCM + val = __shfl_up_sync(0xffffffff, skip_prob, i); +#else + val = __shfl_up(skip_prob, i); +#endif + if (i <= threadIdx.x) { + skip_prob = skip_prob + val; + } + } + + betas[idxr(bTgt, t, U - 1)] = + betas[idxr(bTgt, T - 1 - blockIdx.x * blockDim.x, U - 1)] + skip_prob; + } + + if (t >= 0 && u >= 0) { + CAST_DTYPE skip_prob = + logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_SKIP_IDX]; + CAST_DTYPE emit_prob = + logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_EMIT_IDX]; + + CAST_DTYPE skip = betas[idxr(bTgt, t + threadIdx.x + 1, u)] + skip_prob; + CAST_DTYPE emit = betas[idxr(bTgt, t, u + 1)] + emit_prob; + + CAST_DTYPE val = math::lse(skip, emit); + CAST_DTYPE out = val; + + for (int i = 1; i < warpSize; ++i) { +#ifndef USE_ROCM + val = __shfl_up_sync(0xffffffff, val, 1); +#else + val = __shfl_up(val, 1); +#endif + if (i == threadIdx.x) { + val = math::lse(val + skip_prob, emit); + out = val; + } + } + + betas[idxr(bTgt, t, u)] = out; + + if (t == 0 && u == 0) { // use -beta(0, 0) as cost. + costs[bTgt] = DTYPE(-out); + } + } + + if (threadIdx.x == 0) { + __threadfence(); + atomicAdd(counter, 1); + } +} + +template +__global__ void ComputeAlphasBetasCosts( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + const CAST_DTYPE* logProbs, + const int* srcLengths, + const int* tgtLengths, + int* alpha_counters, + volatile CAST_DTYPE* alphas, + int* betaCounters, + volatile CAST_DTYPE* betas, + DTYPE* costs, + int warpSize = 0, + int numWarps = 0, + int H = 1) { + assert(threadIdx.y == 0 || threadIdx.y == 1); + + if (threadIdx.y == 0) { + ComputeAlphas( + /*maxSrcLen=*/maxSrcLen, + /*maxTgtLen=*/maxTgtLen, + /*numTargets=*/numTargets, + /*blank=*/blank, + /*logProbs=*/logProbs, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*alpha_counters=*/alpha_counters, + /*alphas=*/alphas, + H); + } else { // threadIdx.y == 1 + ComputeBetasCosts( + /*maxSrcLen=*/maxSrcLen, + /*maxTgtLen=*/maxTgtLen, + /*numTargets=*/numTargets, + /*blank=*/blank, + /*logProbs=*/logProbs, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*betaCounters=*/betaCounters, + /*beta=*/betas, + /*costs=*/costs, + H); + } +} + +template +__global__ void ComputeGradients( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + CAST_DTYPE clamp, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + const CAST_DTYPE* denominators, + const CAST_DTYPE* alphas, + const CAST_DTYPE* betas, + DTYPE* gradients, + int H = 1, + bool fusedLogSmax = true) { + const int bTgt = blockIdx.z; // 0 <= b < B + const int t = blockIdx.x * blockDim.x + threadIdx.x; + const int u = blockIdx.y; + + ComputeGradientsElement( + bTgt, + t, + u, + maxSrcLen, + maxTgtLen, + numTargets, + blank, + clamp, + logits, + targets, + srcLengths, + tgtLengths, + denominators, + alphas, + betas, + gradients, + H, + fusedLogSmax); +} + +// This is a __global__ wrapper around ComputeAlphas +// device kernel to enable unit testing +template +__global__ void ComputeAlphasWrapper( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + const CAST_DTYPE* logProbs, + const int* srcLengths, + const int* tgtLengths, + int* alpha_counters, + volatile CAST_DTYPE* alphas, + int H = 1) { + ComputeAlphas( + maxSrcLen, + maxTgtLen, + numTargets, + blank, + logProbs, + srcLengths, + tgtLengths, + alpha_counters, + alphas, + H); +} + +// This is a __global__ wrapper around ComputeBetas +// device kernel to enable unit testing +template +__global__ void ComputeBetasWrapper( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + const CAST_DTYPE* logProbs, + const int* srcLengths, + const int* tgtLengths, + int* betaCounters, + volatile CAST_DTYPE* betas, + DTYPE* costs, + int H = 1) { + ComputeBetasCosts( + maxSrcLen, + maxTgtLen, + numTargets, + blank, + logProbs, + srcLengths, + tgtLengths, + betaCounters, + betas, + costs, + H); +} + +// #undef LOG_PROBS_SKIP_IDX +// #undef LOG_PROBS_EMIT_IDX + +} // namespace rnnt +} // namespace torchaudio + +#endif // USE_ROCM diff --git a/src/libtorchaudio/rnnt/gpu/gpu_transducer.h b/src/libtorchaudio/rnnt/gpu/gpu_transducer.h index 875c47974f..2b2c13d431 100644 --- a/src/libtorchaudio/rnnt/gpu/gpu_transducer.h +++ b/src/libtorchaudio/rnnt/gpu/gpu_transducer.h @@ -3,8 +3,13 @@ #ifdef USE_CUDA #include +#ifdef __HIP_PLATFORM_AMD__ +#include +#include +#else #include #include +#endif namespace torchaudio { namespace rnnt { diff --git a/src/libtorchaudio/rnnt/gpu/gpu_transducer_hip.h b/src/libtorchaudio/rnnt/gpu/gpu_transducer_hip.h new file mode 100644 index 0000000000..20456ff3fe --- /dev/null +++ b/src/libtorchaudio/rnnt/gpu/gpu_transducer_hip.h @@ -0,0 +1,402 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +#pragma once + +#ifdef USE_ROCM + +#include +#ifdef __HIP_PLATFORM_AMD__ +#include +#include +#else +#include +#include +#endif + +namespace torchaudio { +namespace rnnt { +namespace gpu { + +#define gpuErrchk(ans) \ + { gpuAssert((ans), __FILE__, __LINE__); } + +inline void gpuAssert( + hipError_t code, + const char* file, + int line, + bool abort = true) { + if (code != hipSuccess) { + fprintf( + stderr, + "\nGPUassert: %s %s %d\n", + hipGetErrorString(code), + file, + line); + if (abort) + exit(code); + } +} + +template +status_t LogSumExp2D( + hipStream_t stream, + int N, + int D, + const DTYPE* logits, // [N, D] + CAST_DTYPE* outputs) { + { // compute max among D. + dim3 block_dims(N); + dim3 thread_dims(REDUCE_THREADS); + + hipLaunchKernelGGL(( ReduceMax2D) + , dim3(block_dims), dim3(thread_dims), 0, stream, + /*dim=*/D, + /*inputs=*/logits, + /*outputs=*/outputs); + + // BUGBUG: These error codes are only accurate when launching with + // blocking. Otherwise they usually reflect earlier errors. + if (hipGetLastError() != hipSuccess) { + return COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED; + } + } + + { // compute log(sum(exp(d_i - max))) + dim3 block_dims(N); + dim3 thread_dims(REDUCE_THREADS); + + hipLaunchKernelGGL(( ReduceLogSumExpGivenMax2D) + , dim3(block_dims), dim3(thread_dims), 0, stream, + /*dim=*/D, + /*inputs=*/logits, + /*outputs=*/outputs); + + if (hipGetLastError() != hipSuccess) { + return COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED; + } + } + + return SUCCESS; +} + +// Inputs: +// workspace: workspace. +// logits: pointer to (B, max_T, max_U, D) logits. +// targets: pointer to (B, max_U - 1) targets in the batch. +// srcLengths: pointer to (B, ) source lengths in the batch. +// tgtLengths: pointer to (B, ) target lengths in the batch. +// +// Outputs: +// costs: pointer to (B, ) costs in the batch. +// gradients: pointer to (B, max_T, max_U, D) gradients in the batch. +template +status_t Compute( + const Workspace& workspace, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + DTYPE* costs, + DTYPE* gradients = nullptr) { + const Options& options = workspace.GetOptions(); + + const hipStream_t& stream = options.stream_; + const int& B = options.batchSize_; + const int& H = options.nHypos_; + const int& max_T = options.maxSrcLen_; + const int& max_U = options.maxTgtLen_; + const int& D = options.numTargets_; + const int& blank = options.blank_; + const CAST_DTYPE clamp = options.clamp_; + + const bool& fusedLogSmax = options.fusedLogSmax_; + + { // compute denominators. + status_t status = LogSumExp2D( + /*stream=*/stream, + /*N=*/B * H * max_T * max_U, + /*D=*/D, + /*logits=*/logits, + /*denominators=*/workspace.GetPointerToDenominators()); + + if (status != SUCCESS) { + return status; + } + } + + { // compute log probability pairs (blank and target). + int num_segments = + (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK; + dim3 block_dims(num_segments, max_U, B * H); + dim3 thread_dims(MAX_THREADS_PER_BLOCK); + + hipLaunchKernelGGL(( ComputeLogProbs), dim3(block_dims), dim3(thread_dims), 0, stream, + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*log_probs=*/workspace.GetPointerToLogProbs(), + H, + fusedLogSmax); + + if (hipGetLastError() != hipSuccess) { + return COMPUTE_LOG_PROBS_FAILED; + } + } + + { // compute alphas, betas and costs. + // warp is usually a group of threads (32) + int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE; + + // each block is identified by 3 d tuple. + // we are using num_warp * max_U * B * H blocks + // where num_warp is division among Time axis + dim3 block_dims(num_warps, max_U, B * H); + + // each thread is identified by a 2 d tuple + // 2nd dim is 2. 1 for alpha, 1 for beta + dim3 thread_dims(WARP_SIZE, 2); + + hipLaunchKernelGGL(( ComputeAlphasBetasCosts) + , dim3(block_dims), dim3(thread_dims), 0, stream, + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*log_probs=*/workspace.GetPointerToLogProbs(), + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*alpha_counters=*/workspace.GetPointerToAlphaCounters(), + /*alphas=*/workspace.GetPointerToAlphas(), + /*beta_counters=*/workspace.GetPointerToBetaCounters(), + /*betas=*/workspace.GetPointerToBetas(), + /*costs=*/costs, + /*warp_size=*/WARP_SIZE, + /*num_warps=*/num_warps, + H); + if (hipGetLastError() != hipSuccess) { + return COMPUTE_ALPHAS_BETAS_COSTS_FAILED; + } + } + + if (gradients != nullptr) { // compute gradients. + // don't set gradients to zero to here as gradients might reuse memory from + // logits + + int num_blocks = + (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK; + dim3 block_dims(num_blocks, max_U, B * H); + dim3 thread_dims(MAX_THREADS_PER_BLOCK); + + hipLaunchKernelGGL(( ComputeGradients), dim3(block_dims), dim3(thread_dims), 0, stream, + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*clamp=*/clamp, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*alphas=*/workspace.GetPointerToAlphas(), + /*betas=*/workspace.GetPointerToBetas(), + /*gradients=*/gradients, + H, + fusedLogSmax); + if (hipGetLastError() != hipSuccess) { + return COMPUTE_GRADIENTS_FAILED; + } + } + + return SUCCESS; +} + +template +status_t ComputeAlphas( + const Workspace& workspace, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + DTYPE* alphas) { + const Options& options = workspace.GetOptions(); + + const hipStream_t& stream = options.stream_; + const int& B = options.batchSize_; + const int& H = options.nHypos_; + const int& max_T = options.maxSrcLen_; + const int& max_U = options.maxTgtLen_; + const int& D = options.numTargets_; + const int& blank = options.blank_; + + { // compute denominators. + status_t status = LogSumExp2D( + /*stream=*/stream, + /*N=*/B * H * max_T * max_U, + /*D=*/D, + /*logits=*/logits, + /*denominators=*/workspace.GetPointerToDenominators()); + + if (status != SUCCESS) { + return status; + } + } + + { // compute log probability pairs (blank and target). + int num_segments = + (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK; + dim3 block_dims(num_segments, max_U, B * H); + dim3 thread_dims(MAX_THREADS_PER_BLOCK); + + hipLaunchKernelGGL(( ComputeLogProbs), dim3(block_dims), dim3(thread_dims), 0, stream, + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*log_probs=*/workspace.GetPointerToLogProbs(), + H); + + if (hipGetLastError() != hipSuccess) { + return COMPUTE_LOG_PROBS_FAILED; + } + } + { // compute alphas + // warp is usually a group of threads (32) + int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE; + + // each block is identified by 3 d tuple. + // we are using num_warp * max_U * B blocks + // where num_warp is division among Time axis + dim3 block_dims(num_warps, max_U, B * H); + + // each thread is identified by a 2 d tuple + // 2nd dim is 1 for alpha only + dim3 thread_dims(WARP_SIZE, 1); + + hipLaunchKernelGGL(( ComputeAlphasWrapper) + , dim3(block_dims), dim3(thread_dims), 0, stream, + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*log_probs=*/workspace.GetPointerToLogProbs(), + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*alpha_counters=*/workspace.GetPointerToAlphaCounters(), + /*alphas=*/(volatile DTYPE*)alphas, + H); + + if (hipGetLastError() != hipSuccess) { + return COMPUTE_ALPHAS_BETAS_COSTS_FAILED; + } + } + + return SUCCESS; +} + +template +status_t ComputeBetas( + const Workspace& workspace, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + DTYPE* costs, + DTYPE* betas) { + const Options& options = workspace.GetOptions(); + + const hipStream_t& stream = options.stream_; + const int& B = options.batchSize_; + const int& H = options.nHypos_; + const int& max_T = options.maxSrcLen_; + const int& max_U = options.maxTgtLen_; + const int& D = options.numTargets_; + const int& blank = options.blank_; + + { // compute denominators. + status_t status = LogSumExp2D( + /*stream=*/stream, + /*N=*/B * H * max_T * max_U, + /*D=*/D, + /*logits=*/logits, + /*denominators=*/workspace.GetPointerToDenominators()); + + if (status != SUCCESS) { + return status; + } + } + + { // compute log probability pairs (blank and target). + int num_segments = + (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK; + dim3 block_dims(num_segments, max_U, B * H); + dim3 thread_dims(MAX_THREADS_PER_BLOCK); + + hipLaunchKernelGGL(( ComputeLogProbs), dim3(block_dims), dim3(thread_dims), 0, stream, + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*log_probs=*/workspace.GetPointerToLogProbs(), + H); + + if (hipGetLastError() != hipSuccess) { + return COMPUTE_LOG_PROBS_FAILED; + } + } + { // compute betas + // warp is usually a group of threads (32) + int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE; + + // each block is identified by 3 d tuple. + // we are using num_warp * max_U * B blocks + // where num_warp is division among Time axis + dim3 block_dims(num_warps, max_U, B * H); + + // each thread is identified by a 2 d tuple + // 2nd dim is 1 for betas only + dim3 thread_dims(WARP_SIZE, 1); + + hipLaunchKernelGGL(( ComputeBetasWrapper) + , dim3(block_dims), dim3(thread_dims), 0, stream, + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*log_probs=*/workspace.GetPointerToLogProbs(), + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*alpha_counters=*/workspace.GetPointerToBetaCounters(), + /*alphas=*/(volatile DTYPE*)betas, + costs, + H); + + if (hipGetLastError() != hipSuccess) { + return COMPUTE_ALPHAS_BETAS_COSTS_FAILED; + } + } + + return SUCCESS; +} + +} // namespace gpu +} // namespace rnnt +} // namespace torchaudio + +#endif // USE_ROCM diff --git a/src/libtorchaudio/rnnt/gpu/kernel_utils.h b/src/libtorchaudio/rnnt/gpu/kernel_utils.h index 9cfaf42cdd..da0b99e30c 100644 --- a/src/libtorchaudio/rnnt/gpu/kernel_utils.h +++ b/src/libtorchaudio/rnnt/gpu/kernel_utils.h @@ -2,7 +2,11 @@ #include +#ifdef __HIP_PLATFORM_AMD__ +#include +#else #include +#endif namespace torchaudio { namespace rnnt { diff --git a/src/libtorchaudio/rnnt/gpu/kernels.h b/src/libtorchaudio/rnnt/gpu/kernels.h index 5f327d3ee3..46980394f8 100644 --- a/src/libtorchaudio/rnnt/gpu/kernels.h +++ b/src/libtorchaudio/rnnt/gpu/kernels.h @@ -2,8 +2,13 @@ #include +#ifdef __HIP_PLATFORM_AMD__ +#include +#include +#else #include #include +#endif namespace torchaudio { namespace rnnt { diff --git a/src/libtorchaudio/rnnt/gpu/math_hip.cuh b/src/libtorchaudio/rnnt/gpu/math_hip.cuh new file mode 100644 index 0000000000..963d741e67 --- /dev/null +++ b/src/libtorchaudio/rnnt/gpu/math_hip.cuh @@ -0,0 +1,49 @@ +// !!! This is a file automatically generated by hipify!!! +#pragma once + +#ifdef USE_ROCM + +#include + +#endif // USE_ROCM + +#include + +namespace torchaudio { +namespace rnnt { + +namespace math { + +template +FORCE_INLINE HOST_AND_DEVICE DTYPE max(DTYPE x, DTYPE y) { + if (x > y) + return x; + else + return y; +} + +template +FORCE_INLINE HOST_AND_DEVICE DTYPE min(DTYPE x, DTYPE y) { + if (x > y) + return y; + else + return x; +} + +// log_sum_exp +template +FORCE_INLINE HOST_AND_DEVICE DTYPE lse(DTYPE x, DTYPE y); + +template <> +FORCE_INLINE HOST_AND_DEVICE float lse(float x, float y) { + if (y > x) { + return y + log1pf(expf(x - y)); + } else { + return x + log1pf(expf(y - x)); + } +} + +} // namespace math + +} // namespace rnnt +} // namespace torchaudio diff --git a/src/libtorchaudio/rnnt/macros.h b/src/libtorchaudio/rnnt/macros.h index cdc83dd5d2..f1677f9198 100644 --- a/src/libtorchaudio/rnnt/macros.h +++ b/src/libtorchaudio/rnnt/macros.h @@ -8,6 +8,14 @@ #define FORCE_INLINE __forceinline__ #include #include +#elif USE_ROCM +#define WARP_SIZE 32 +#define MAX_THREADS_PER_BLOCK 1024 +#define REDUCE_THREADS 256 +#define HOST_AND_DEVICE __host__ __device__ +#define FORCE_INLINE __forceinline__ +#include +#include #else #define HOST_AND_DEVICE #define FORCE_INLINE inline diff --git a/src/libtorchaudio/rnnt/options.h b/src/libtorchaudio/rnnt/options.h index 8a8fed1116..c3b5bdfa4d 100644 --- a/src/libtorchaudio/rnnt/options.h +++ b/src/libtorchaudio/rnnt/options.h @@ -2,7 +2,12 @@ #ifdef USE_CUDA #include +typedef cudaStream_t gpuStream_t; #endif // USE_CUDA +#ifdef USE_ROCM +#include +typedef hipStream_t gpuStream_t; +#endif // USE_ROCM #include #include @@ -13,9 +18,9 @@ namespace rnnt { struct Options { // the device to compute transducer loss. device_t device_; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) // the stream to launch kernels in when using GPU. - cudaStream_t stream_; + gpuStream_t stream_; #endif // The maximum number of threads that can be used. int numThreads_; diff --git a/src/libtorchaudio/rnnt/workspace.h b/src/libtorchaudio/rnnt/workspace.h index b4bbb30a43..0d457c5c78 100644 --- a/src/libtorchaudio/rnnt/workspace.h +++ b/src/libtorchaudio/rnnt/workspace.h @@ -133,10 +133,22 @@ class IntWorkspace { ComputeSizeForBetaCounters(options_) * sizeof(int)); } #endif // USE_CUDA +#ifdef USE_ROCM + if (data_ != nullptr && options_.device_ == GPU) { + hipMemset( + GetPointerToAlphaCounters(), + 0, + ComputeSizeForAlphaCounters(options_) * sizeof(int)); + hipMemset( + GetPointerToBetaCounters(), + 0, + ComputeSizeForBetaCounters(options_) * sizeof(int)); + } +#endif // USE_ROCM } static int ComputeSizeForAlphaCounters(const Options& options) { // B * U -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) if (options.device_ == GPU) { return options.BU(); } else { @@ -147,7 +159,7 @@ class IntWorkspace { #endif // USE_CUDA } static int ComputeSizeForBetaCounters(const Options& options) { // B * U -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) if (options.device_ == GPU) { return options.BU(); } else { diff --git a/third_party/hipify_torch b/third_party/hipify_torch new file mode 160000 index 0000000000..ee928d80eb --- /dev/null +++ b/third_party/hipify_torch @@ -0,0 +1 @@ +Subproject commit ee928d80eb49a74be5d556465e04c6a40de7e3bc