diff --git a/src/libtorchaudio/forced_align/gpu/compute.cu b/src/libtorchaudio/forced_align/gpu/compute.cu
index 444a4f8f6d..c3be77a256 100644
--- a/src/libtorchaudio/forced_align/gpu/compute.cu
+++ b/src/libtorchaudio/forced_align/gpu/compute.cu
@@ -8,6 +8,10 @@
 #include
 #include
 
+#ifdef USE_ROCM
+  #include
+#endif
+
 namespace {
 
 constexpr int kNumThreads = 1024; // Number of threads to run CUDA kernel in parallel.
diff --git a/src/libtorchaudio/hip_namespace_shim.h b/src/libtorchaudio/hip_namespace_shim.h
new file mode 100644
index 0000000000..7d03d919a7
--- /dev/null
+++ b/src/libtorchaudio/hip_namespace_shim.h
@@ -0,0 +1,31 @@
+#pragma once
+
+#include
+
+namespace libtorchaudio::hip {
+
+inline hipStream_t getCurrentHIPStreamMasqueradingAsCUDA(
+    torch::stable::DeviceIndex device_index = -1) {
+  return cuda::getCurrentHIPStreamMasqueradingAsCUDA(device_index);
+}
+
+inline void setCurrentHIPStreamMasqueradingAsCUDA(
+    hipStream_t stream,
+    torch::stable::DeviceIndex device_index = -1) {
+  cuda::setCurrentHIPStreamMasqueradingAsCUDA(stream, device_index);
+}
+
+inline hipStream_t getStreamFromPoolMasqueradingAsCUDA(
+    const bool isHighPriority = false,
+    torch::stable::DeviceIndex device_index = -1) {
+  return cuda::getStreamFromPoolMasqueradingAsCUDA(
+      isHighPriority, device_index);
+}
+
+inline void synchronize(
+    hipStream_t stream,
+    torch::stable::DeviceIndex device_index = -1) {
+  cuda::synchronize(stream, device_index);
+}
+
+} // namespace libtorchaudio::hip
diff --git a/src/libtorchaudio/iir_cuda.cu b/src/libtorchaudio/iir_cuda.cu
index 658bca4c54..31919ee617 100644
--- a/src/libtorchaudio/iir_cuda.cu
+++ b/src/libtorchaudio/iir_cuda.cu
@@ -69,12 +69,12 @@ Tensor cuda_lfilter_core_loop(
   const dim3 blocks((N * C + threads.x - 1) / threads.x);
 
   THO_DISPATCH_V2(
-      in.scalar_type(), "iir_cu_loop", AT_WRAP([&] {
-        (iir_cu_kernel<<>>(
-            torchaudio::packed_accessor_size_t(in),
-            torchaudio::packed_accessor_size_t(a_flipped),
-            torchaudio::packed_accessor_size_t(padded_out)));
-        STD_CUDA_KERNEL_LAUNCH_CHECK();
-      }), AT_FLOATING_TYPES);
+      in.scalar_type(), "iir_cu_loop", AT_WRAP(([&]() {
+        iir_cu_kernel<<>>(
+            torchaudio::packed_accessor_size_t(in),
+            torchaudio::packed_accessor_size_t(a_flipped),
+            torchaudio::packed_accessor_size_t(padded_out));
+        STD_CUDA_KERNEL_LAUNCH_CHECK();
+      })), AT_FLOATING_TYPES);
   return padded_out;
 }
diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu
index 7e99fec395..40c5945f81 100644
--- a/src/libtorchaudio/rnnt/gpu/compute.cu
+++ b/src/libtorchaudio/rnnt/gpu/compute.cu
@@ -6,6 +6,10 @@
 #include
 #include
 
+#ifdef USE_ROCM
+  #include
+#endif
+
 namespace torchaudio {
 namespace rnnt {
 namespace gpu {
diff --git a/test/torchaudio_unittest/functional/functional_cpu_test.py b/test/torchaudio_unittest/functional/functional_cpu_test.py
index c953277518..c3b14c38db 100644
--- a/test/torchaudio_unittest/functional/functional_cpu_test.py
+++ b/test/torchaudio_unittest/functional/functional_cpu_test.py
@@ -1,7 +1,7 @@
 import unittest
 
 import torch
-from torchaudio_unittest.common_utils import PytorchTestCase, skipIfRocm
+from torchaudio_unittest.common_utils import PytorchTestCase
 
 from .functional_impl import Functional, FunctionalCPUOnly
 
@@ -11,11 +11,10 @@ class TestFunctionalFloat32(Functional, FunctionalCPUOnly, PytorchTestCase):
     device = torch.device("cpu")
 
     @unittest.expectedFailure
-    @skipIfRocm
     def test_lfilter_9th_order_filter_stability(self):
         super().test_lfilter_9th_order_filter_stability()
 
-@skipIfRocm
+
 class TestFunctionalFloat64(Functional, PytorchTestCase):
     dtype = torch.float64
     device = torch.device("cpu")
diff --git a/test/torchaudio_unittest/functional/functional_cuda_test.py b/test/torchaudio_unittest/functional/functional_cuda_test.py
index ad74db8279..037b053099 100644
--- a/test/torchaudio_unittest/functional/functional_cuda_test.py
+++ b/test/torchaudio_unittest/functional/functional_cuda_test.py
@@ -1,8 +1,7 @@
 import unittest
 
 import torch
-
-from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda, skipIfSingleCuda, skipIfRocm
+from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda, skipIfSingleCuda
 
 from .functional_impl import Functional, FunctionalCUDAOnly
 
@@ -13,13 +12,11 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
     device = torch.device("cuda")
 
     @unittest.expectedFailure
-    @skipIfRocm
     def test_lfilter_9th_order_filter_stability(self):
         super().test_lfilter_9th_order_filter_stability()
 
 
 @skipIfNoCuda
-@skipIfRocm
 class TestLFilterFloat64(Functional, PytorchTestCase):
     dtype = torch.float64
     device = torch.device("cuda")
diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py
index 50e403fd34..e2e42f1fe7 100644
--- a/test/torchaudio_unittest/functional/functional_impl.py
+++ b/test/torchaudio_unittest/functional/functional_impl.py
@@ -14,7 +14,6 @@
     get_whitenoise,
     nested_params,
     rnnt_utils,
-    skipIfRocm,
     TestBaseMixin,
 )
 
@@ -632,13 +631,11 @@ def test_pitch_shift_shape(self, n_steps):
         waveform_shift = F.pitch_shift(waveform, sample_rate, n_steps)
         assert waveform.size() == waveform_shift.size()
 
-    @skipIfRocm
     def test_rnnt_loss_basic_backward(self):
         logits, targets, logit_lengths, target_lengths = rnnt_utils.get_basic_data(self.device)
         loss = F.rnnt_loss(logits, targets, logit_lengths, target_lengths)
         loss.backward()
 
-    @skipIfRocm
     def test_rnnt_loss_basic_forward_no_grad(self):
         """In early stage, calls to `rnnt_loss` resulted in segmentation fault
         when `logits` have `requires_grad = False`. This test makes sure that this no longer
@@ -658,7 +655,6 @@ def test_rnnt_loss_basic_forward_no_grad(self):
             (rnnt_utils.get_B2_T4_U3_D3_data, torch.float16, 1e-3, 1e-2),
         ]
     )
-    @skipIfRocm
     def test_rnnt_loss_costs_and_gradients(self, data_func, dtype, atol, rtol):
         data, ref_costs, ref_gradients = data_func(
             dtype=dtype,
@@ -673,7 +669,6 @@ def test_rnnt_loss_costs_and_gradients(self, data_func, dtype, atol, rtol):
         )
 
     @parameterized.expand([(True,), (False,)])
-    @skipIfRocm
     def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self, fused_log_softmax):
         seed = 777
         for i in range(5):
diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py
index 3b65f15017..66bf1739a0 100644
--- a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py
+++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py
@@ -796,7 +796,6 @@ def test_deemphasis(self):
 
 
 class FunctionalFloat32Only(TestBaseMixin):
-    @skipIfRocm
     def test_rnnt_loss(self):
         def func(tensor):
             targets = torch.tensor([[1, 2]], device=tensor.device, dtype=torch.int32)
diff --git a/test/torchaudio_unittest/transforms/autograd_test_impl.py b/test/torchaudio_unittest/transforms/autograd_test_impl.py
index 72dae54437..e3cf93ba39 100644
--- a/test/torchaudio_unittest/transforms/autograd_test_impl.py
+++ b/test/torchaudio_unittest/transforms/autograd_test_impl.py
@@ -5,14 +5,7 @@
 import torchaudio.transforms as T
 from parameterized import parameterized
 from torch.autograd import gradcheck, gradgradcheck
-from torchaudio_unittest.common_utils import (
-    TestBaseMixin,
-    get_spectrogram,
-    get_whitenoise,
-    nested_params,
-    rnnt_utils,
-    skipIfRocm,
-)
+from torchaudio_unittest.common_utils import get_spectrogram, get_whitenoise, nested_params, rnnt_utils, TestBaseMixin
 
 
 class _DeterministicWrapper(torch.nn.Module):
@@ -80,7 +73,6 @@ def test_spectrogram(self, kwargs):
         waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
         self.assert_grad(transform, [waveform], nondet_tol=1e-10)
 
-    @skipIfRocm
     def test_inverse_spectrogram(self):
         # create a realistic input:
         waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
@@ -104,7 +96,6 @@ def test_melspectrogram(self):
         [0, 0.99],
         [False, True],
     )
-    @skipIfRocm
     def test_griffinlim(self, momentum, rand_init):
         n_fft = 80
         power = 1
@@ -124,7 +115,6 @@ def test_mfcc(self, log_mels):
         self.assert_grad(transform, [waveform], nondet_tol=1e-10)
 
     @parameterized.expand([(False,), (True,)])
-    @skipIfRocm
     def test_lfcc(self, log_lf):
         sample_rate = 8000
         transform = T.LFCC(sample_rate=sample_rate, log_lf=log_lf)
diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
index 4a6281669b..6ada3351ab 100644
--- a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
+++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
@@ -273,7 +273,6 @@ def test_deemphasis(self):
 
 
 class TransformsFloat32Only(TestBaseMixin):
-    @skipIfRocm
     def test_rnnt_loss(self):
         logits = torch.tensor(
             [
diff --git a/tools/setup_helpers/extension.py b/tools/setup_helpers/extension.py
index 447dd5091d..ea5be54a8a 100644
--- a/tools/setup_helpers/extension.py
+++ b/tools/setup_helpers/extension.py
@@ -34,7 +34,7 @@ def _get_build(var, default=False):
 _USE_ROCM = _get_build("USE_ROCM", torch.backends.cuda.is_built() and torch.version.hip is not None)
 _USE_CUDA = _get_build("USE_CUDA", torch.backends.cuda.is_built() and torch.version.hip is None)
 _BUILD_ALIGN = _get_build("BUILD_ALIGN", True)
-_BUILD_CUDA_CTC_DECODER = _get_build("BUILD_CUDA_CTC_DECODER", _USE_CUDA)
+_BUILD_CUDA_CTC_DECODER = _get_build("BUILD_CUDA_CTC_DECODER", _USE_CUDA or _USE_ROCM)
 _USE_OPENMP = _get_build("USE_OPENMP", True) and "ATen parallel backend: OpenMP" in torch.__config__.parallel_info()
 _TORCH_CUDA_ARCH_LIST = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
 
@@ -71,6 +71,18 @@ def get_ext_modules():
         extension = CUDAExtension
         extra_compile_args["cxx"].append("-DUSE_CUDA")
         extra_compile_args["nvcc"] = ["-O2", "-DUSE_CUDA"]
+    if _USE_ROCM:
+        extension = CUDAExtension
+        extra_compile_args["cxx"].append("-DHIPBLAS_V2")
+        extra_compile_args["nvcc"] = ["-O3", "-DHIPBLAS_V2"]
+        # TORCH_HIP_VERSION is used by hipified C++ (e.g. utils_hip.cpp); PyTorch only defines it when building PyTorch.
+        if torch.version.hip:
+            parts = torch.version.hip.split(".")
+            major = int(parts[0]) if len(parts) > 0 else 0
+            minor = int(parts[1]) if len(parts) > 1 else 0
+            torch_hip_version = major * 100 + minor  # e.g. 7.1.x -> 701
+            extra_compile_args["cxx"].append("-DTORCH_HIP_VERSION=" + str(torch_hip_version))
+            extra_compile_args["nvcc"].append("-DTORCH_HIP_VERSION=" + str(torch_hip_version))
 
     sources = [
         "utils.cpp",
@@ -78,7 +90,7 @@ def get_ext_modules():
         "overdrive.cpp",
     ]
 
-    if _USE_CUDA:
+    if _USE_CUDA or _USE_ROCM:
         sources.append("iir_cuda.cu")
 
     if _BUILD_RNNT:
@@ -88,7 +100,7 @@ def get_ext_modules():
                 "rnnt/compute.cpp",
             ]
         )
-        if _USE_CUDA:
+        if _USE_CUDA or _USE_ROCM:
             sources.append("rnnt/gpu/compute.cu")
 
     if _BUILD_ALIGN:
@@ -99,7 +111,7 @@ def get_ext_modules():
                 "forced_align/compute.cpp",
             ]
         )
-        if _USE_CUDA:
+        if _USE_CUDA or _USE_ROCM:
             sources.append("forced_align/gpu/compute.cu")
 
     modules = [