Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "third_party/hipify_torch"]
path = third_party/hipify_torch
url = https://github.com/ROCmSoftwarePlatform/hipify_torch
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,11 @@ if(USE_CUDA AND USE_ROCM)
endif()

if(USE_ROCM)
enable_language(HIP)
# Find the HIP package, set the HIP paths, load the HIP CMake.
include(cmake/LoadHIP.cmake)
if(NOT PYTORCH_FOUND_HIP)
set(USE_ROCM OFF)
#set(USE_ROCM OFF)
endif()
endif()

Expand Down
50 changes: 47 additions & 3 deletions src/libtorchaudio/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
################################################################################
# libtorchaudio
################################################################################
if(USE_ROCM)
  # Locate the HIP toolchain. /opt/rocm is the default ROCm install prefix;
  # NOTE(review): hardcoded path — consider honoring $ENV{ROCM_PATH} as well.
  list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm)
  find_package(HIP REQUIRED)
  # Bug fix: FIND_PACKAGE(HIP) sets HIP_FOUND, not ROCM_FOUND — the old
  # message always printed an empty/undefined value.
  message(STATUS "hip found ${HIP_FOUND}")

  # hipify_torch provides the hipify() CMake function used below to translate
  # CUDA sources to HIP.
  list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/third_party/hipify_torch/cmake")
  include(Hipify)

  # Force hipcc as the C++ compiler/linker for this directory so the .hip
  # sources are compiled for AMD GPUs.
  set(CMAKE_CXX_COMPILER ${HIP_HIPCC_EXECUTABLE})
  set(CMAKE_CXX_LINKER ${HIP_HIPCC_EXECUTABLE})
  # Embed the ROCm LLVM runtime library path so the installed library can
  # find its runtime dependencies without LD_LIBRARY_PATH.
  set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
  list(APPEND CMAKE_INSTALL_RPATH "/opt/rocm/llvm/lib")
endif()

set(
sources
lfilter.cpp
Expand Down Expand Up @@ -32,6 +47,19 @@ if(BUILD_RNNT)
rnnt/gpu/compute.cu
)
endif()
if (USE_ROCM)
hipify(CUDA_SOURCE_DIR ${PROJECT_SOURCE_DIR}/src/libtorchaudio/rnnt/gpu)
if ( NOT HIP_ADD_LIBRARY_FOUND )
list(APPEND CMAKE_MODULE_PATH /opt/rocm/hip/cmake)
find_package(HIP REQUIRED)
endif()

list(
APPEND
sources
rnnt/gpu/compute.hip
)
endif()
endif()

if(BUILD_ALIGN)
Expand Down Expand Up @@ -64,12 +92,28 @@ if(USE_CUDA)
)
endif()

if(OpenMP_CXX_FOUND)
if(USE_ROCM)
list(
APPEND
additional_libs
OpenMP::OpenMP_CXX
additional_libs
hip::host
hip::device
)
list(
APPEND
compile_definitions
USE_ROCM
)
endif()

if(USE_CUDA)
if(OpenMP_CXX_FOUND)
list(
APPEND
additional_libs
OpenMP::OpenMP_CXX
)
endif()
endif()

#------------------------------------------------------------------------------#
Expand Down
4 changes: 4 additions & 0 deletions src/libtorchaudio/rnnt/gpu/compute.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
#ifdef __HIP_PLATFORM_AMD__
#include <libtorchaudio/rnnt/gpu/gpu_transducer_hip.h>
#else
#include <libtorchaudio/rnnt/gpu/gpu_transducer.h>
#endif

#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAStream.h>
Expand Down
177 changes: 177 additions & 0 deletions src/libtorchaudio/rnnt/gpu/compute.hip
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
// !!! This is a file automatically generated by hipify!!!
#ifdef __HIP_PLATFORM_AMD__
#include <libtorchaudio/rnnt/gpu/gpu_transducer_hip.h>
#else
#include <libtorchaudio/rnnt/gpu/gpu_transducer.h>
#endif

#include <c10/hip/HIPException.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>

namespace torchaudio {
namespace rnnt {
namespace gpu {

using torch::stable::Tensor;
using torch::headeronly::ScalarType;

// Entry point into RNNT Loss
std::tuple<Tensor, Tensor> compute(
const Tensor& logits,
const Tensor& targets,
const Tensor& logit_lengths,
const Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_softmax = true) {
STD_TORCH_CHECK(logits.is_cuda(), "logits must be on CUDA");

STD_TORCH_CHECK(
targets.is_cuda() && targets.get_device_index() == logits.get_device_index(),
"logits and targets must be on the same device");
STD_TORCH_CHECK(
logit_lengths.is_cuda() && logit_lengths.get_device_index() == logits.get_device_index(),
"logits and logit_lengths must be on the same device");
STD_TORCH_CHECK(
target_lengths.is_cuda() && target_lengths.get_device_index() == logits.get_device_index(),
"logits and target_lengths must be on the same device");

STD_TORCH_CHECK(
logits.scalar_type() == ScalarType::Float || logits.scalar_type() == ScalarType::Half,
"logits must be float32 or float16 (half) type");

STD_TORCH_CHECK(targets.scalar_type() == ScalarType::Int, "targets must be int32 type");

STD_TORCH_CHECK(
logit_lengths.scalar_type() == ScalarType::Int,
"logit_lengths must be int32 type");
STD_TORCH_CHECK(
target_lengths.scalar_type() == ScalarType::Int,
"target_lengths must be int32 type");

STD_TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
STD_TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
STD_TORCH_CHECK(
logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
STD_TORCH_CHECK(
target_lengths.is_contiguous(), "target_lengths must be contiguous");

STD_TORCH_CHECK(
logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
STD_TORCH_CHECK(
targets.dim() == 2, "targets must be 2-D (batch, max target length)");
STD_TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
STD_TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");

STD_TORCH_CHECK(
logit_lengths.size(0) == logits.size(0),
"batch dimension mismatch between logits and logit_lengths");
STD_TORCH_CHECK(
target_lengths.size(0) == logits.size(0),
"batch dimension mismatch between logits and target_lengths");
STD_TORCH_CHECK(
targets.size(0) == logits.size(0),
"batch dimension mismatch between logits and targets");

STD_TORCH_CHECK(
blank >= 0 && blank < logits.size(-1),
"blank must be within [0, logits.shape[-1])");

auto max_ivalue = [](const Tensor& t) {
int32_t value;
C10_HIP_CHECK(hipMemcpy(&value, torch::stable::amax(t, {}).data_ptr(), sizeof(int32_t), hipMemcpyDeviceToHost));
return value;
};

STD_TORCH_CHECK(
logits.size(1) == max_ivalue(logit_lengths),
"input length mismatch");
STD_TORCH_CHECK(
logits.size(2) == max_ivalue(target_lengths) + 1,
"output length mismatch");
STD_TORCH_CHECK(
targets.size(1) + 1 == logits.size(2),
"target length mismatch");

Options options;
options.batchSize_ = logit_lengths.size(0);
options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
options.fusedLogSmax_ = fused_log_softmax;
options.stream_ = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
hipSetDevice(logits.get_device());
options.device_ = GPU;

Tensor costs = torch::stable::new_empty(logits, {options.batchSize_ * options.nHypos_});
Tensor gradients = torch::stable::empty_like(logits);
torch::stable::fill_(gradients, 0.0);

Tensor int_workspace = torch::stable::new_empty(logits, {IntWorkspace::ComputeSizeFromOptions(options)}, ScalarType::Int);
Tensor float_workspace = torch::stable::new_empty(logits, {DtypeWorkspace<float>::ComputeSizeFromOptions(options)}, ScalarType::Float);

Workspace<float> workspace(
/*options=*/options,
/*dtype_data=*/reinterpret_cast<float*>(float_workspace.data_ptr()),
/*dtype_size=*/float_workspace.numel(),
/*int_data=*/reinterpret_cast<int*>(int_workspace.data_ptr()),
/*int_size=*/int_workspace.numel());

switch (logits.scalar_type()) {
case ScalarType::Float: {
Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/reinterpret_cast<float*>(logits.data_ptr()),
/*targets=*/reinterpret_cast<int*>(targets.data_ptr()),
/*srcLengths=*/reinterpret_cast<int*>(logit_lengths.data_ptr()),
/*tgtLengths=*/reinterpret_cast<int*>(target_lengths.data_ptr()),
/*costs=*/reinterpret_cast<float*>(costs.data_ptr()),
/*gradients=*/reinterpret_cast<float*>(gradients.data_ptr()));
break;
}
case ScalarType::Half: {
Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/reinterpret_cast<c10::Half*>(logits.data_ptr()),
/*targets=*/reinterpret_cast<int*>(targets.data_ptr()),
/*srcLengths=*/reinterpret_cast<int*>(logit_lengths.data_ptr()),
/*tgtLengths=*/reinterpret_cast<int*>(target_lengths.data_ptr()),
/*costs=*/reinterpret_cast<c10::Half*>(costs.data_ptr()),
/*gradients=*/reinterpret_cast<c10::Half*>(gradients.data_ptr()));
break;
}
default: {
STD_TORCH_CHECK(false, "unreachable");
}
};

return std::make_tuple(costs, gradients);
}

// Boxed stable-ABI wrapper around compute().
//
// Unpacks the 7 arguments from the StableIValue stack, calls compute(), and
// writes the two result tensors (costs, gradients) back onto the stack.
void boxed_rnnt_loss(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  STD_TORCH_CHECK(num_args == 7, "num_args must be 7");
  STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
  std::tuple<Tensor, Tensor> res = compute(
      /*logits*/torch::stable::detail::to<Tensor>(stack[0]),
      /*targets*/torch::stable::detail::to<Tensor>(stack[1]),
      /*logit_lengths*/torch::stable::detail::to<Tensor>(stack[2]),
      /*target_lengths*/torch::stable::detail::to<Tensor>(stack[3]),
      // Bug fix: `blank` is an int64_t label index; the previous code routed
      // it through float(), a lossy and semantically wrong narrowing before
      // implicit conversion back to int64_t.
      /*blank*/torch::stable::detail::to<int64_t>(stack[4]),
      /*clamp*/torch::stable::detail::to<double>(stack[5]),
      /*fused_log_softmax*/torch::stable::detail::to<bool>(stack[6]));
  stack[0] = torch::stable::detail::from(std::get<0>(res));
  stack[1] = torch::stable::detail::from(std::get<1>(res));
}

// Register the boxed kernel for torchaudio::rnnt_loss_forward under the CUDA
// dispatch key. NOTE(review): this is a .hip file, but ROCm builds present
// HIP as CUDA to the dispatcher (cf. the MasqueradingAsCUDA stream header
// included above), so the CUDA key is intentional — confirm against the
// op's schema registration elsewhere in the project.
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
m.impl("rnnt_loss_forward", &boxed_rnnt_loss);
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
4 changes: 4 additions & 0 deletions src/libtorchaudio/rnnt/gpu/gpu_kernel_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

#ifdef USE_CUDA

#ifdef __HIP_PLATFORM_AMD__
#include <libtorchaudio/rnnt/gpu/math_hip.cuh>
#else
#include <libtorchaudio/rnnt/gpu/math.cuh>
#endif

namespace torchaudio {
namespace rnnt {
Expand Down
Loading
Loading