From b232095931f964981eeeae93004289e5daf4266c Mon Sep 17 00:00:00 2001
From: skishore
Date: Thu, 19 Feb 2026 13:47:22 +0000
Subject: [PATCH 01/12] add source code for rnnt loss, define TORCH_HIP_VERSION
 variable

---
 tools/setup_helpers/extension.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/tools/setup_helpers/extension.py b/tools/setup_helpers/extension.py
index 447dd5091d..f1bcfc2275 100644
--- a/tools/setup_helpers/extension.py
+++ b/tools/setup_helpers/extension.py
@@ -71,6 +71,19 @@ def get_ext_modules():
         extension = CUDAExtension
         extra_compile_args["cxx"].append("-DUSE_CUDA")
         extra_compile_args["nvcc"] = ["-O2", "-DUSE_CUDA"]
+    if _USE_ROCM:
+        extension = CUDAExtension
+        extra_compile_args["cxx"].append("-DHIPBLAS_V2")
+        extra_compile_args["nvcc"] = ["-O3", "-DHIPBLAS_V2"]
+        # TORCH_HIP_VERSION is used by hipified C++ (e.g. utils_hip.cpp); PyTorch only defines it when building PyTorch.
+        if torch.version.hip:
+            parts = torch.version.hip.split(".")
+            major = int(parts[0]) if len(parts) > 0 else 0
+            minor = int(parts[1]) if len(parts) > 1 else 0
+            patch = int(parts[2].split("-")[0]) if len(parts) > 2 else 0
+            torch_hip_version = major * 100 + minor * 10 + patch  # e.g. 6.0.1 -> 601
+            extra_compile_args["cxx"].append("-DTORCH_HIP_VERSION=" + str(torch_hip_version))
+            extra_compile_args["nvcc"].append("-DTORCH_HIP_VERSION=" + str(torch_hip_version))
 
     sources = [
         "utils.cpp",
@@ -88,7 +101,7 @@ def get_ext_modules():
             "rnnt/compute.cpp",
         ]
     )
-    if _USE_CUDA:
+    if _USE_CUDA or _USE_ROCM:
         sources.append("rnnt/gpu/compute.cu")
 
     if _BUILD_ALIGN:

From fa315a15cd035307dc88fa0b1070ab87d9089108 Mon Sep 17 00:00:00 2001
From: skishore
Date: Thu, 19 Feb 2026 14:02:08 +0000
Subject: [PATCH 02/12] add namespace shim file

---
 src/libtorchaudio/hip_namespace_shim.h | 30 ++++++++++++++++++++++++++
 src/libtorchaudio/rnnt/gpu/compute.cu  |  4 ++++
 2 files changed, 34 insertions(+)
 create mode 100644 src/libtorchaudio/hip_namespace_shim.h

diff --git a/src/libtorchaudio/hip_namespace_shim.h b/src/libtorchaudio/hip_namespace_shim.h
new file mode 100644
index 0000000000..2c944ec251
--- /dev/null
+++ b/src/libtorchaudio/hip_namespace_shim.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include
+
+namespace libtorchaudio::hip {
+
+inline hipStream_t getCurrentHIPStreamMasqueradingAsCUDA(
+    torch::stable::DeviceIndex device_index = -1) {
+  return cuda::getCurrentHIPStreamMasqueradingAsCUDA(device_index);
+}
+
+inline void setCurrentHIPStreamMasqueradingAsCUDA(
+    hipStream_t stream,
+    torch::stable::DeviceIndex device_index = -1) {
+  cuda::setCurrentHIPStreamMasqueradingAsCUDA(stream, device_index);
+}
+
+inline hipStream_t getStreamFromPoolMasqueradingAsCUDA(
+    const bool isHighPriority = false,
+    torch::stable::DeviceIndex device_index = -1) {
+  return cuda::getStreamFromPoolMasqueradingAsCUDA(isHighPriority, device_index);
+}
+
+inline void synchronize(
+    hipStream_t stream,
+    torch::stable::DeviceIndex device_index = -1) {
+  cuda::synchronize(stream, device_index);
+}
+
+} // namespace libtorchaudio::hip
\ No newline at end of file
diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu
index 7e99fec395..2c4ce205a8 100644
--- a/src/libtorchaudio/rnnt/gpu/compute.cu
+++ b/src/libtorchaudio/rnnt/gpu/compute.cu
@@ -6,6 +6,10 @@
 #include
 #include
 
+#ifdef USE_ROCM
+#include <libtorchaudio/hip_namespace_shim.h>
+#endif
+
 namespace torchaudio {
 namespace rnnt {
 namespace gpu {
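Reviewer note: the shim above exists because building .cu sources for ROCm runs
them through PyTorch's hipify step, which rewrites CUDA identifiers into their
HIP "masquerading" equivalents before compilation; the shim gives hand-written
ROCm paths one stable namespace for the renamed stream entry points. A rough
sketch of that rewriting, using an illustrative subset of the mapping (the
real, much larger table ships in torch.utils.hipify.cuda_to_hip_mappings):

    # Illustrative subset of hipify's source-to-source renames.
    RENAMES = {
        "cudaStream_t": "hipStream_t",
        "cudaStreamSynchronize": "hipStreamSynchronize",
        "getCurrentCUDAStream": "getCurrentHIPStreamMasqueradingAsCUDA",
    }

    def hipify_source(source: str) -> str:
        for cuda_name, hip_name in RENAMES.items():
            source = source.replace(cuda_name, hip_name)
        return source

    print(hipify_source("cudaStream_t s = getCurrentCUDAStream();"))
    # hipStream_t s = getCurrentHIPStreamMasqueradingAsCUDA();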
From c13048ae16912b90bfd24225d27691499c8c8641 Mon Sep 17 00:00:00 2001
From: skishore
Date: Fri, 20 Feb 2026 13:43:10 +0000
Subject: [PATCH 03/12] fix THO_DISPATCH syntax so that hipification works

---
 src/libtorchaudio/iir_cuda.cu | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/libtorchaudio/iir_cuda.cu b/src/libtorchaudio/iir_cuda.cu
index 658bca4c54..31919ee617 100644
--- a/src/libtorchaudio/iir_cuda.cu
+++ b/src/libtorchaudio/iir_cuda.cu
@@ -69,12 +69,12 @@ Tensor cuda_lfilter_core_loop(
   const dim3 blocks((N * C + threads.x - 1) / threads.x);
   THO_DISPATCH_V2(
-      in.scalar_type(), "iir_cu_loop", AT_WRAP([&] {
-        (iir_cu_kernel<scalar_t><<<blocks, threads>>>(
-            torchaudio::packed_accessor_size_t<scalar_t, 3>(in),
-            torchaudio::packed_accessor_size_t<scalar_t, 2>(a_flipped),
-            torchaudio::packed_accessor_size_t<scalar_t, 3>(padded_out)));
-        STD_CUDA_KERNEL_LAUNCH_CHECK();
-      }), AT_FLOATING_TYPES);
+      in.scalar_type(), "iir_cu_loop", AT_WRAP(([&]() {
+        iir_cu_kernel<scalar_t><<<blocks, threads>>>(
+            torchaudio::packed_accessor_size_t<scalar_t, 3>(in),
+            torchaudio::packed_accessor_size_t<scalar_t, 2>(a_flipped),
+            torchaudio::packed_accessor_size_t<scalar_t, 3>(padded_out));
+        STD_CUDA_KERNEL_LAUNCH_CHECK();
+      })), AT_FLOATING_TYPES);
 
   return padded_out;
 }

From 6c8ab7f853789e7a924ae4df2398b369a1d7d216 Mon Sep 17 00:00:00 2001
From: skishore
Date: Fri, 20 Feb 2026 13:43:56 +0000
Subject: [PATCH 04/12] fix TORCH_HIP_VERSION encoding, add lfilter rocm to
 sources

---
 tools/setup_helpers/extension.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tools/setup_helpers/extension.py b/tools/setup_helpers/extension.py
index f1bcfc2275..01e24e1fa7 100644
--- a/tools/setup_helpers/extension.py
+++ b/tools/setup_helpers/extension.py
@@ -80,8 +80,7 @@ def get_ext_modules():
             parts = torch.version.hip.split(".")
             major = int(parts[0]) if len(parts) > 0 else 0
             minor = int(parts[1]) if len(parts) > 1 else 0
-            patch = int(parts[2].split("-")[0]) if len(parts) > 2 else 0
-            torch_hip_version = major * 100 + minor * 10 + patch  # e.g. 6.0.1 -> 601
+            torch_hip_version = major * 100 + minor  # e.g. 7.1.x -> 701
             extra_compile_args["cxx"].append("-DTORCH_HIP_VERSION=" + str(torch_hip_version))
             extra_compile_args["nvcc"].append("-DTORCH_HIP_VERSION=" + str(torch_hip_version))
 
@@ -91,7 +90,7 @@ def get_ext_modules():
         "overdrive.cpp",
     ]
 
-    if _USE_CUDA:
+    if _USE_CUDA or _USE_ROCM:
         sources.append("iir_cuda.cu")
 
     if _BUILD_RNNT:
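Reviewer note: the encoding fix in patch 4 matters because ROCm's third version
component is a large build number, not a small patch level -- torch.version.hip
looks like "6.2.41133-<hash>" -- so patch 1's major * 100 + minor * 10 + patch
arithmetic produced values far outside the three-digit scheme that
TORCH_HIP_VERSION guards compare against. A standalone sketch of the corrected
encoding (the function name here is ours, for illustration only):

    def encode_torch_hip_version(hip_version: str) -> int:
        """Collapse a torch.version.hip string into the -DTORCH_HIP_VERSION int.

        >>> encode_torch_hip_version("7.1.25124-abc123")
        701
        >>> encode_torch_hip_version("6.4")
        604
        """
        parts = hip_version.split(".")
        major = int(parts[0]) if len(parts) > 0 else 0
        minor = int(parts[1]) if len(parts) > 1 else 0
        return major * 100 + minor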
From 5b46635a0b8c28b1dd334872f39dba6a91c80b56 Mon Sep 17 00:00:00 2001
From: Sriram Kumar
Date: Sat, 21 Feb 2026 03:24:58 -0600
Subject: [PATCH 05/12] add rocm source code for forced align, fix hip
 namespace shim includes

---
 src/libtorchaudio/forced_align/gpu/compute.cu | 4 ++++
 src/libtorchaudio/rnnt/gpu/compute.cu         | 2 +-
 tools/setup_helpers/extension.py              | 2 +-
 3 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/libtorchaudio/forced_align/gpu/compute.cu b/src/libtorchaudio/forced_align/gpu/compute.cu
index 444a4f8f6d..c3be77a256 100644
--- a/src/libtorchaudio/forced_align/gpu/compute.cu
+++ b/src/libtorchaudio/forced_align/gpu/compute.cu
@@ -8,6 +8,10 @@
 #include
 #include
 
+#ifdef USE_ROCM
+  #include <libtorchaudio/hip_namespace_shim.h>
+#endif
+
 namespace {
 
 constexpr int kNumThreads = 1024; // Number of threads to run CUDA kernel in parallel.
diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu
index 2c4ce205a8..40c5945f81 100644
--- a/src/libtorchaudio/rnnt/gpu/compute.cu
+++ b/src/libtorchaudio/rnnt/gpu/compute.cu
@@ -7,7 +7,7 @@
 #include
 
 #ifdef USE_ROCM
-#include <libtorchaudio/hip_namespace_shim.h>
+  #include <libtorchaudio/hip_namespace_shim.h>
 #endif
 
 namespace torchaudio {
diff --git a/tools/setup_helpers/extension.py b/tools/setup_helpers/extension.py
index 01e24e1fa7..1c8fd299a0 100644
--- a/tools/setup_helpers/extension.py
+++ b/tools/setup_helpers/extension.py
@@ -111,7 +111,7 @@ def get_ext_modules():
             "forced_align/compute.cpp",
         ]
     )
-    if _USE_CUDA:
+    if _USE_CUDA or _USE_ROCM:
         sources.append("forced_align/gpu/compute.cu")
 
     modules = [

From ab3424f7dbea75d9449a734eabe0d4bc29a5296a Mon Sep 17 00:00:00 2001
From: Sriram Kumar
Date: Sat, 21 Feb 2026 11:01:49 -0600
Subject: [PATCH 06/12] remove extra skipIfRocm flags from certain unit tests

---
 test/torchaudio_unittest/functional/functional_cuda_test.py | 5 +----
 test/torchaudio_unittest/functional/functional_impl.py      | 5 -----
 .../functional/torchscript_consistency_impl.py              | 1 -
 .../transforms/torchscript_consistency_impl.py              | 1 -
 4 files changed, 1 insertion(+), 11 deletions(-)

diff --git a/test/torchaudio_unittest/functional/functional_cuda_test.py b/test/torchaudio_unittest/functional/functional_cuda_test.py
index ad74db8279..037b053099 100644
--- a/test/torchaudio_unittest/functional/functional_cuda_test.py
+++ b/test/torchaudio_unittest/functional/functional_cuda_test.py
@@ -1,8 +1,7 @@
 import unittest
 
 import torch
-
-from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda, skipIfSingleCuda, skipIfRocm
+from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda, skipIfSingleCuda
 
 from .functional_impl import Functional, FunctionalCUDAOnly
@@ -13,13 +12,11 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
     device = torch.device("cuda")
 
     @unittest.expectedFailure
-    @skipIfRocm
     def test_lfilter_9th_order_filter_stability(self):
         super().test_lfilter_9th_order_filter_stability()
 
 
 @skipIfNoCuda
-@skipIfRocm
 class TestLFilterFloat64(Functional, PytorchTestCase):
     dtype = torch.float64
     device = torch.device("cuda")
diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py
index 50e403fd34..e2e42f1fe7 100644
--- a/test/torchaudio_unittest/functional/functional_impl.py
+++ b/test/torchaudio_unittest/functional/functional_impl.py
@@ -14,7 +14,6 @@
     get_whitenoise,
     nested_params,
     rnnt_utils,
-    skipIfRocm,
     TestBaseMixin,
 )
 
@@ -632,13 +631,11 @@ def test_pitch_shift_shape(self, n_steps):
         waveform_shift = F.pitch_shift(waveform, sample_rate, n_steps)
         assert waveform.size() == waveform_shift.size()
 
-    @skipIfRocm
     def test_rnnt_loss_basic_backward(self):
         logits, targets, logit_lengths, target_lengths = rnnt_utils.get_basic_data(self.device)
         loss = F.rnnt_loss(logits, targets, logit_lengths, target_lengths)
         loss.backward()
 
-    @skipIfRocm
     def test_rnnt_loss_basic_forward_no_grad(self):
         """In early stage, calls to `rnnt_loss` resulted in segmentation fault when
         `logits` have `requires_grad = False`. This test makes sure that this no longer
@@ -658,7 +655,6 @@ def test_rnnt_loss_basic_forward_no_grad(self):
             (rnnt_utils.get_B2_T4_U3_D3_data, torch.float16, 1e-3, 1e-2),
         ]
     )
-    @skipIfRocm
     def test_rnnt_loss_costs_and_gradients(self, data_func, dtype, atol, rtol):
         data, ref_costs, ref_gradients = data_func(
             dtype=dtype,
@@ -673,7 +669,6 @@ def test_rnnt_loss_costs_and_gradients(self, data_func, dtype, atol, rtol):
         )
 
     @parameterized.expand([(True,), (False,)])
-    @skipIfRocm
     def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self, fused_log_softmax):
         seed = 777
         for i in range(5):
diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py
index 3b65f15017..66bf1739a0 100644
--- a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py
+++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py
@@ -796,7 +796,6 @@ def test_deemphasis(self):
 
 
 class FunctionalFloat32Only(TestBaseMixin):
-    @skipIfRocm
     def test_rnnt_loss(self):
         def func(tensor):
             targets = torch.tensor([[1, 2]], device=tensor.device, dtype=torch.int32)
diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
index 4a6281669b..6ada3351ab 100644
--- a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
+++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
@@ -273,7 +273,6 @@ def test_deemphasis(self):
 
 
 class TransformsFloat32Only(TestBaseMixin):
-    @skipIfRocm
     def test_rnnt_loss(self):
         logits = torch.tensor(
             [
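Reviewer note: the tests re-enabled above exercise torchaudio.functional.rnnt_loss
end to end, including the backward pass that was previously skipped on ROCm. A
minimal standalone version of that smoke test (shapes here are illustrative and
chosen to satisfy rnnt_loss's layout; the real tests draw data from rnnt_utils):

    import torch
    import torchaudio.functional as F

    device = "cuda" if torch.cuda.is_available() else "cpu"  # ROCm builds also report cuda

    # logits: (batch, max_time, max_target_length + 1, num_classes); lengths are int32.
    logits = torch.rand(1, 4, 3, 5, device=device, requires_grad=True)
    targets = torch.tensor([[1, 2]], device=device, dtype=torch.int32)
    logit_lengths = torch.tensor([4], device=device, dtype=torch.int32)
    target_lengths = torch.tensor([2], device=device, dtype=torch.int32)

    loss = F.rnnt_loss(logits, targets, logit_lengths, target_lengths)
    loss.backward()  # the forward and backward calls the re-enabled tests cover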
From 45aec763c8e82c279a4dad4d708e3fbfe6448f41 Mon Sep 17 00:00:00 2001
From: Sriram Kumar
Date: Sat, 21 Feb 2026 12:24:42 -0600
Subject: [PATCH 07/12] build cuda ctc decoder for rocm

---
 tools/setup_helpers/extension.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tools/setup_helpers/extension.py b/tools/setup_helpers/extension.py
index 1c8fd299a0..ea5be54a8a 100644
--- a/tools/setup_helpers/extension.py
+++ b/tools/setup_helpers/extension.py
@@ -34,7 +34,7 @@ def _get_build(var, default=False):
 _USE_ROCM = _get_build("USE_ROCM", torch.backends.cuda.is_built() and torch.version.hip is not None)
 _USE_CUDA = _get_build("USE_CUDA", torch.backends.cuda.is_built() and torch.version.hip is None)
 _BUILD_ALIGN = _get_build("BUILD_ALIGN", True)
-_BUILD_CUDA_CTC_DECODER = _get_build("BUILD_CUDA_CTC_DECODER", _USE_CUDA)
+_BUILD_CUDA_CTC_DECODER = _get_build("BUILD_CUDA_CTC_DECODER", _USE_CUDA or _USE_ROCM)
 _USE_OPENMP = _get_build("USE_OPENMP", True) and "ATen parallel backend: OpenMP" in torch.__config__.parallel_info()
 _TORCH_CUDA_ARCH_LIST = os.environ.get("TORCH_CUDA_ARCH_LIST", None)

From bc95506f17d2dc43353f893043f4bf0a97d1a4a7 Mon Sep 17 00:00:00 2001
From: Sriram Kumar
Date: Sat, 21 Feb 2026 12:29:06 -0600
Subject: [PATCH 08/12] remove skipIfRocm from the cpu tests: the lfilter test
 now passes and the other tests match upstream

---
 test/torchaudio_unittest/functional/functional_cpu_test.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/test/torchaudio_unittest/functional/functional_cpu_test.py b/test/torchaudio_unittest/functional/functional_cpu_test.py
index c953277518..f36fd2f5b3 100644
--- a/test/torchaudio_unittest/functional/functional_cpu_test.py
+++ b/test/torchaudio_unittest/functional/functional_cpu_test.py
@@ -1,7 +1,7 @@
 import unittest
 
 import torch
-from torchaudio_unittest.common_utils import PytorchTestCase, skipIfRocm
+from torchaudio_unittest.common_utils import PytorchTestCase
 
 from .functional_impl import Functional, FunctionalCPUOnly
@@ -11,11 +11,9 @@ class TestFunctionalFloat32(Functional, FunctionalCPUOnly, PytorchTestCase):
     device = torch.device("cpu")
 
     @unittest.expectedFailure
-    @skipIfRocm
     def test_lfilter_9th_order_filter_stability(self):
         super().test_lfilter_9th_order_filter_stability()
 
-@skipIfRocm
 class TestFunctionalFloat64(Functional, PytorchTestCase):
     dtype = torch.float64
     device = torch.device("cpu")
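Reviewer note: on how the defaults in patch 7's hunk tell the two backends
apart -- a ROCm wheel still reports torch.backends.cuda.is_built() as True, so
torch.version.hip is the discriminator, and _get_build lets environment
variables (e.g. USE_ROCM=1 or BUILD_CUDA_CTC_DECODER=0) override any default.
The same logic, as a standalone sketch:

    import torch

    # Mirrors the default expressions in tools/setup_helpers/extension.py.
    use_rocm = torch.backends.cuda.is_built() and torch.version.hip is not None
    use_cuda = torch.backends.cuda.is_built() and torch.version.hip is None
    build_cuda_ctc_decoder = use_cuda or use_rocm

    print(f"{use_cuda=} {use_rocm=} {build_cuda_ctc_decoder=}")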
From 553f9b0c890e90cf3a56d5869376efd4e5b9efd5 Mon Sep 17 00:00:00 2001
From: Sriram Kumar
Date: Mon, 23 Feb 2026 04:49:32 -0600
Subject: [PATCH 09/12] add missing newline at end of file

---
 src/libtorchaudio/hip_namespace_shim.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/libtorchaudio/hip_namespace_shim.h b/src/libtorchaudio/hip_namespace_shim.h
index 2c944ec251..38ba0b32bd 100644
--- a/src/libtorchaudio/hip_namespace_shim.h
+++ b/src/libtorchaudio/hip_namespace_shim.h
@@ -27,4 +27,4 @@ inline void synchronize(
   cuda::synchronize(stream, device_index);
 }
 
-} // namespace libtorchaudio::hip
\ No newline at end of file
+} // namespace libtorchaudio::hip

From f795931af4435465caed47129b2cbcbad9adaa8a Mon Sep 17 00:00:00 2001
From: Sriram Kumar
Date: Mon, 23 Feb 2026 05:28:08 -0600
Subject: [PATCH 10/12] fix ufmt issue

---
 test/torchaudio_unittest/functional/functional_cpu_test.py | 1 +
 test/torchaudio_unittest/transforms/autograd_test_impl.py  | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/test/torchaudio_unittest/functional/functional_cpu_test.py b/test/torchaudio_unittest/functional/functional_cpu_test.py
index f36fd2f5b3..c3b14c38db 100644
--- a/test/torchaudio_unittest/functional/functional_cpu_test.py
+++ b/test/torchaudio_unittest/functional/functional_cpu_test.py
@@ -14,6 +14,7 @@ class TestFunctionalFloat32(Functional, FunctionalCPUOnly, PytorchTestCase):
     def test_lfilter_9th_order_filter_stability(self):
         super().test_lfilter_9th_order_filter_stability()
 
+
 class TestFunctionalFloat64(Functional, PytorchTestCase):
     dtype = torch.float64
     device = torch.device("cpu")
diff --git a/test/torchaudio_unittest/transforms/autograd_test_impl.py b/test/torchaudio_unittest/transforms/autograd_test_impl.py
index 72dae54437..208e39e5b7 100644
--- a/test/torchaudio_unittest/transforms/autograd_test_impl.py
+++ b/test/torchaudio_unittest/transforms/autograd_test_impl.py
@@ -6,12 +6,12 @@
 from parameterized import parameterized
 from torch.autograd import gradcheck, gradgradcheck
 from torchaudio_unittest.common_utils import (
-    TestBaseMixin,
     get_spectrogram,
     get_whitenoise,
     nested_params,
    rnnt_utils,
     skipIfRocm,
+    TestBaseMixin,
 )
 
 
 class _DeterministicWrapper(torch.nn.Module):

From e7bad3510fcd63350da7a426b4d98b85b609332a Mon Sep 17 00:00:00 2001
From: Sriram Kumar
Date: Mon, 23 Feb 2026 08:11:16 -0600
Subject: [PATCH 11/12] fix clang format

---
 src/libtorchaudio/hip_namespace_shim.h | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/libtorchaudio/hip_namespace_shim.h b/src/libtorchaudio/hip_namespace_shim.h
index 38ba0b32bd..7d03d919a7 100644
--- a/src/libtorchaudio/hip_namespace_shim.h
+++ b/src/libtorchaudio/hip_namespace_shim.h
@@ -18,7 +18,8 @@ inline void setCurrentHIPStreamMasqueradingAsCUDA(
 inline hipStream_t getStreamFromPoolMasqueradingAsCUDA(
     const bool isHighPriority = false,
     torch::stable::DeviceIndex device_index = -1) {
-  return cuda::getStreamFromPoolMasqueradingAsCUDA(isHighPriority, device_index);
+  return cuda::getStreamFromPoolMasqueradingAsCUDA(
+      isHighPriority, device_index);
 }
 
 inline void synchronize(
     hipStream_t stream,
     torch::stable::DeviceIndex device_index = -1) {
   cuda::synchronize(stream, device_index);
 }
From 82acc5278466cd7593ad80da7e1dc9d9d34a7e37 Mon Sep 17 00:00:00 2001
From: Sriram Kumar
Date: Mon, 23 Feb 2026 10:46:26 -0600
Subject: [PATCH 12/12] remove skipIfRocm flags from the transforms autograd
 tests

---
 .../transforms/autograd_test_impl.py | 12 +-----------
 1 file changed, 1 insertion(+), 11 deletions(-)

diff --git a/test/torchaudio_unittest/transforms/autograd_test_impl.py b/test/torchaudio_unittest/transforms/autograd_test_impl.py
index 208e39e5b7..e3cf93ba39 100644
--- a/test/torchaudio_unittest/transforms/autograd_test_impl.py
+++ b/test/torchaudio_unittest/transforms/autograd_test_impl.py
@@ -5,14 +5,7 @@
 import torchaudio.transforms as T
 from parameterized import parameterized
 from torch.autograd import gradcheck, gradgradcheck
-from torchaudio_unittest.common_utils import (
-    get_spectrogram,
-    get_whitenoise,
-    nested_params,
-    rnnt_utils,
-    skipIfRocm,
-    TestBaseMixin,
-)
+from torchaudio_unittest.common_utils import get_spectrogram, get_whitenoise, nested_params, rnnt_utils, TestBaseMixin
 
 
 class _DeterministicWrapper(torch.nn.Module):
@@ -80,7 +73,6 @@ def test_spectrogram(self, kwargs):
         waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
         self.assert_grad(transform, [waveform], nondet_tol=1e-10)
 
-    @skipIfRocm
     def test_inverse_spectrogram(self):
         # create a realistic input:
         waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
@@ -104,7 +96,6 @@ def test_melspectrogram(self):
         [0, 0.99],
         [False, True],
     )
-    @skipIfRocm
     def test_griffinlim(self, momentum, rand_init):
         n_fft = 80
         power = 1
@@ -124,7 +115,6 @@ def test_mfcc(self, log_mels):
         self.assert_grad(transform, [waveform], nondet_tol=1e-10)
 
     @parameterized.expand([(False,), (True,)])
-    @skipIfRocm
     def test_lfcc(self, log_lf):
         sample_rate = 8000
         transform = T.LFCC(sample_rate=sample_rate, log_lf=log_lf)
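Reviewer note: autograd_test_impl.py drives each transform through
torch.autograd.gradcheck / gradgradcheck in double precision, which is the
machinery that now runs unskipped on ROCm. A minimal standalone version of that
pattern (the transform and shapes here are illustrative, not taken from the
test file):

    import torch
    import torchaudio.transforms as T
    from torch.autograd import gradcheck, gradgradcheck

    # Double-precision input with requires_grad, as the assert_grad helper uses.
    transform = T.Spectrogram(n_fft=64, power=2.0).to(torch.float64)
    waveform = torch.randn(2, 400, dtype=torch.float64, requires_grad=True)

    assert gradcheck(transform, (waveform,))
    assert gradgradcheck(transform, (waveform,))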