Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/libtorchaudio/forced_align/gpu/compute.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
#include <cub/cub.cuh>
#include <limits.h>

#ifdef USE_ROCM
#include <libtorchaudio/hip_namespace_shim.h>
#endif

namespace {
constexpr int kNumThreads =
1024; // Number of threads to run CUDA kernel in parallel.
Expand Down
31 changes: 31 additions & 0 deletions src/libtorchaudio/hip_namespace_shim.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#pragma once

#include <libtorchaudio/hip_utils.h>

namespace libtorchaudio::hip {

// Namespace shim for ROCm builds: re-exports the "MasqueradingAsCUDA" HIP
// stream helpers (declared via libtorchaudio/hip_utils.h) under
// libtorchaudio::hip, so hipified sources can call them through this
// namespace. Every function here is a pure pass-through.
// NOTE(review): the unqualified `cuda::` below is resolved through the
// declarations pulled in by hip_utils.h — confirm it names the intended
// namespace (e.g. PyTorch's c10/HIP masquerading layer) and not an
// unrelated global `cuda` namespace.

// Returns the HIP stream currently active for `device_index`.
// `device_index = -1` presumably selects the current device — TODO confirm
// against the underlying API in hip_utils.h.
inline hipStream_t getCurrentHIPStreamMasqueradingAsCUDA(
    torch::stable::DeviceIndex device_index = -1) {
  return cuda::getCurrentHIPStreamMasqueradingAsCUDA(device_index);
}

// Makes `stream` the current HIP stream for `device_index`
// (pass-through; same -1 default-device convention as above — confirm).
inline void setCurrentHIPStreamMasqueradingAsCUDA(
    hipStream_t stream,
    torch::stable::DeviceIndex device_index = -1) {
  cuda::setCurrentHIPStreamMasqueradingAsCUDA(stream, device_index);
}

// Obtains a stream from the pool for `device_index`; `isHighPriority`
// presumably requests a high-priority stream — semantics come from the
// wrapped API, verify in hip_utils.h.
inline hipStream_t getStreamFromPoolMasqueradingAsCUDA(
    const bool isHighPriority = false,
    torch::stable::DeviceIndex device_index = -1) {
  return cuda::getStreamFromPoolMasqueradingAsCUDA(
      isHighPriority, device_index);
}

// Blocks until all work queued on `stream` has completed (pass-through).
inline void synchronize(
    hipStream_t stream,
    torch::stable::DeviceIndex device_index = -1) {
  cuda::synchronize(stream, device_index);
}

} // namespace libtorchaudio::hip
14 changes: 7 additions & 7 deletions src/libtorchaudio/iir_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ Tensor cuda_lfilter_core_loop(
const dim3 blocks((N * C + threads.x - 1) / threads.x);

THO_DISPATCH_V2(
in.scalar_type(), "iir_cu_loop", AT_WRAP([&] {
(iir_cu_kernel<scalar_t><<<blocks, threads>>>(
torchaudio::packed_accessor_size_t<scalar_t, 3>(in),
torchaudio::packed_accessor_size_t<scalar_t, 2>(a_flipped),
torchaudio::packed_accessor_size_t<scalar_t, 3>(padded_out)));
STD_CUDA_KERNEL_LAUNCH_CHECK();
}), AT_FLOATING_TYPES);
in.scalar_type(), "iir_cu_loop", AT_WRAP(([&]() {
iir_cu_kernel<scalar_t><<<blocks, threads>>>(
torchaudio::packed_accessor_size_t<scalar_t, 3>(in),
torchaudio::packed_accessor_size_t<scalar_t, 2>(a_flipped),
torchaudio::packed_accessor_size_t<scalar_t, 3>(padded_out));
STD_CUDA_KERNEL_LAUNCH_CHECK();
})), AT_FLOATING_TYPES);
return padded_out;
}
4 changes: 4 additions & 0 deletions src/libtorchaudio/rnnt/gpu/compute.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/Dispatch_v2.h>

#ifdef USE_ROCM
#include <libtorchaudio/hip_namespace_shim.h>
#endif

namespace torchaudio {
namespace rnnt {
namespace gpu {
Expand Down
5 changes: 2 additions & 3 deletions test/torchaudio_unittest/functional/functional_cpu_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest

import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfRocm
from torchaudio_unittest.common_utils import PytorchTestCase

from .functional_impl import Functional, FunctionalCPUOnly

Expand All @@ -11,11 +11,10 @@ class TestFunctionalFloat32(Functional, FunctionalCPUOnly, PytorchTestCase):
device = torch.device("cpu")

@unittest.expectedFailure
@skipIfRocm
def test_lfilter_9th_order_filter_stability(self):
super().test_lfilter_9th_order_filter_stability()

@skipIfRocm

class TestFunctionalFloat64(Functional, PytorchTestCase):
dtype = torch.float64
device = torch.device("cpu")
5 changes: 1 addition & 4 deletions test/torchaudio_unittest/functional/functional_cuda_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import unittest

import torch

from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda, skipIfSingleCuda, skipIfRocm
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda, skipIfSingleCuda

from .functional_impl import Functional, FunctionalCUDAOnly

Expand All @@ -13,13 +12,11 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
device = torch.device("cuda")

@unittest.expectedFailure
@skipIfRocm
def test_lfilter_9th_order_filter_stability(self):
super().test_lfilter_9th_order_filter_stability()


@skipIfNoCuda
@skipIfRocm
class TestLFilterFloat64(Functional, PytorchTestCase):
dtype = torch.float64
device = torch.device("cuda")
Expand Down
5 changes: 0 additions & 5 deletions test/torchaudio_unittest/functional/functional_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
get_whitenoise,
nested_params,
rnnt_utils,
skipIfRocm,
TestBaseMixin,
)

Expand Down Expand Up @@ -632,13 +631,11 @@ def test_pitch_shift_shape(self, n_steps):
waveform_shift = F.pitch_shift(waveform, sample_rate, n_steps)
assert waveform.size() == waveform_shift.size()

@skipIfRocm
def test_rnnt_loss_basic_backward(self):
logits, targets, logit_lengths, target_lengths = rnnt_utils.get_basic_data(self.device)
loss = F.rnnt_loss(logits, targets, logit_lengths, target_lengths)
loss.backward()

@skipIfRocm
def test_rnnt_loss_basic_forward_no_grad(self):
"""In early stage, calls to `rnnt_loss` resulted in segmentation fault when
`logits` have `requires_grad = False`. This test makes sure that this no longer
Expand All @@ -658,7 +655,6 @@ def test_rnnt_loss_basic_forward_no_grad(self):
(rnnt_utils.get_B2_T4_U3_D3_data, torch.float16, 1e-3, 1e-2),
]
)
@skipIfRocm
def test_rnnt_loss_costs_and_gradients(self, data_func, dtype, atol, rtol):
data, ref_costs, ref_gradients = data_func(
dtype=dtype,
Expand All @@ -673,7 +669,6 @@ def test_rnnt_loss_costs_and_gradients(self, data_func, dtype, atol, rtol):
)

@parameterized.expand([(True,), (False,)])
@skipIfRocm
def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self, fused_log_softmax):
seed = 777
for i in range(5):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,6 @@ def test_deemphasis(self):


class FunctionalFloat32Only(TestBaseMixin):
@skipIfRocm
def test_rnnt_loss(self):
def func(tensor):
targets = torch.tensor([[1, 2]], device=tensor.device, dtype=torch.int32)
Expand Down
12 changes: 1 addition & 11 deletions test/torchaudio_unittest/transforms/autograd_test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,7 @@
import torchaudio.transforms as T
from parameterized import parameterized
from torch.autograd import gradcheck, gradgradcheck
from torchaudio_unittest.common_utils import (
TestBaseMixin,
get_spectrogram,
get_whitenoise,
nested_params,
rnnt_utils,
skipIfRocm,
)
from torchaudio_unittest.common_utils import get_spectrogram, get_whitenoise, nested_params, rnnt_utils, TestBaseMixin


class _DeterministicWrapper(torch.nn.Module):
Expand Down Expand Up @@ -80,7 +73,6 @@ def test_spectrogram(self, kwargs):
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform], nondet_tol=1e-10)

@skipIfRocm
def test_inverse_spectrogram(self):
# create a realistic input:
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
Expand All @@ -104,7 +96,6 @@ def test_melspectrogram(self):
[0, 0.99],
[False, True],
)
@skipIfRocm
def test_griffinlim(self, momentum, rand_init):
n_fft = 80
power = 1
Expand All @@ -124,7 +115,6 @@ def test_mfcc(self, log_mels):
self.assert_grad(transform, [waveform], nondet_tol=1e-10)

@parameterized.expand([(False,), (True,)])
@skipIfRocm
def test_lfcc(self, log_lf):
sample_rate = 8000
transform = T.LFCC(sample_rate=sample_rate, log_lf=log_lf)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ def test_deemphasis(self):


class TransformsFloat32Only(TestBaseMixin):
@skipIfRocm
def test_rnnt_loss(self):
logits = torch.tensor(
[
Expand Down
20 changes: 16 additions & 4 deletions tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _get_build(var, default=False):
_USE_ROCM = _get_build("USE_ROCM", torch.backends.cuda.is_built() and torch.version.hip is not None)
_USE_CUDA = _get_build("USE_CUDA", torch.backends.cuda.is_built() and torch.version.hip is None)
_BUILD_ALIGN = _get_build("BUILD_ALIGN", True)
_BUILD_CUDA_CTC_DECODER = _get_build("BUILD_CUDA_CTC_DECODER", _USE_CUDA)
_BUILD_CUDA_CTC_DECODER = _get_build("BUILD_CUDA_CTC_DECODER", _USE_CUDA or _USE_ROCM)
_USE_OPENMP = _get_build("USE_OPENMP", True) and "ATen parallel backend: OpenMP" in torch.__config__.parallel_info()
_TORCH_CUDA_ARCH_LIST = os.environ.get("TORCH_CUDA_ARCH_LIST", None)

Expand Down Expand Up @@ -71,14 +71,26 @@ def get_ext_modules():
extension = CUDAExtension
extra_compile_args["cxx"].append("-DUSE_CUDA")
extra_compile_args["nvcc"] = ["-O2", "-DUSE_CUDA"]
if _USE_ROCM:
extension = CUDAExtension
extra_compile_args["cxx"].append("-DHIPBLAS_V2")
extra_compile_args["nvcc"] = ["-O3", "-DHIPBLAS_V2"]
# TORCH_HIP_VERSION is used by hipified C++ (e.g. utils_hip.cpp); PyTorch only defines it when building PyTorch.
if torch.version.hip:
parts = torch.version.hip.split(".")
major = int(parts[0]) if len(parts) > 0 else 0
minor = int(parts[1]) if len(parts) > 1 else 0
torch_hip_version = major * 100 + minor # e.g. 7.1.x -> 701
extra_compile_args["cxx"].append("-DTORCH_HIP_VERSION=" + str(torch_hip_version))
extra_compile_args["nvcc"].append("-DTORCH_HIP_VERSION=" + str(torch_hip_version))

sources = [
"utils.cpp",
"lfilter.cpp",
"overdrive.cpp",
]

if _USE_CUDA:
if _USE_CUDA or _USE_ROCM:
sources.append("iir_cuda.cu")

if _BUILD_RNNT:
Expand All @@ -88,7 +100,7 @@ def get_ext_modules():
"rnnt/compute.cpp",
]
)
if _USE_CUDA:
if _USE_CUDA or _USE_ROCM:
sources.append("rnnt/gpu/compute.cu")

if _BUILD_ALIGN:
Expand All @@ -99,7 +111,7 @@ def get_ext_modules():
"forced_align/compute.cpp",
]
)
if _USE_CUDA:
if _USE_CUDA or _USE_ROCM:
sources.append("forced_align/gpu/compute.cu")

modules = [
Expand Down
Loading