Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions xla/backends/gpu/codegen/triton/dot_algorithms_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ TEST_F(AlgorithmTest, Algorithm3xBF16) {
}

TEST_F(AlgorithmTest, Algorithm6xBF16) {
if (GpuComputeComp().IsRocm()) {
if (GpuComputeComp().rocm_compute_capability()->gfx9_mi200())
GTEST_SKIP() << "ALG_DOT_BF16_BF16_F32_X6 not supported on MI200.";
}
constexpr absl::string_view kHloText = R"(
HloModule Algorithm6xBF16

Expand Down Expand Up @@ -1055,6 +1059,19 @@ class NumericTestsForBlas : public BlasAlgorithmTest,
)";

protected:
// Skips parameterized BF16-decomposition algorithm tests (X3/X6/X9) on
// ROCm MI200 (gfx90a) devices, where these algorithms are not supported.
// Runs before each test of this parameterized fixture.
void SetUp() override {
// The test parameter is the dot algorithm under test.
PC::Algorithm algorithm = GetParam();
if (GpuComputeComp().IsRocm()) {
// NOTE(review): rocm_compute_capability() is dereferenced with '->',
// so it presumably returns a pointer/optional that is non-null on any
// ROCm build — confirm it cannot be null here.
if (GpuComputeComp().rocm_compute_capability()->gfx9_mi200() &&
(algorithm == PC::ALG_DOT_BF16_BF16_F32_X3 ||
algorithm == PC::ALG_DOT_BF16_BF16_F32_X6 ||
algorithm == PC::ALG_DOT_BF16_BF16_F32_X9)) {
// GTEST_SKIP in SetUp marks the test skipped and prevents the test
// body from running.
GTEST_SKIP() << AlgorithmToString(GetParam())
<< " not supported on MI200.";
}
}
}

std::string algorithm_;
};

Expand Down Expand Up @@ -1551,10 +1568,17 @@ TEST_P(TritonAndBlasSupportForDifferentTensorSizes,
case PC::ALG_DOT_BF16_BF16_F32_X6:
case PC::ALG_DOT_BF16_BF16_F32_X9:
if (GpuComputeComp().IsRocm()) {
// X6 and X9 algorithms on ROCm marked as not supported
// because they often require too much shared memory.
EXPECT_FALSE(result_or_status.value())
<< "algorithms not supported on ROCm";
if (result_or_status.status().ok()) {
// X6 and X9 algorithms on ROCm marked as not supported
// because they often require too much shared memory.
EXPECT_FALSE(result_or_status.value())
<< "algorithms not supported on ROCm";
} else {
if (GpuComputeComp().rocm_compute_capability()->gfx9_mi200()) {
EXPECT_EQ(result_or_status.status().code(),
absl::StatusCode::kInternal);
}
}
} else {
ASSERT_TRUE(result_or_status.status().ok())
<< "failed to compile " << algorithm_;
Expand Down
4 changes: 4 additions & 0 deletions xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,10 @@ CublasLtMatmulThunk::GetCachedMatmulPlan(const ExecuteParams& params) {
auto algorithms,
plan->GetAlgorithms(params.stream, num_algorithms, max_workspace));

if (algorithms.empty()) {
return absl::InternalError(
"Failed to get a MatmulPlan: no valid algorithm found.");
}
TF_RETURN_IF_ERROR(plan->SetAlgorithm(algorithms[algorithm_idx_]));
return std::move(plan);
};
Expand Down
6 changes: 6 additions & 0 deletions xla/service/gpu/dot_algorithm_support_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,12 @@ TEST_P(DotAlgorithmSupportTest, AlgorithmIsSupportedFromCudaCapability) {
if (params.backend_restriction == BackendRestriction::kTritonOnly) {
GTEST_SKIP() << "TODO: Triton unsupported in ROCm";
}
if (rcc->gfx9_mi200() &&
(params.algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 ||
params.algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9)) {
GTEST_SKIP() << AlgorithmToString(params.algorithm)
<< " not supported on MI200.";
}
}

// CublasLt does not support FP8 fast accumulation.
Expand Down
3 changes: 3 additions & 0 deletions xla/service/gpu/gpu_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1951,6 +1951,9 @@ XLA_FFI_DEFINE_HANDLER(

TEST_F(GpuCompilerTest,
ParametersUsedByCollectiveMosaicShouldBeCopiedToCollectiveMemory) {
if (device_description().gpu_compute_capability().IsRocm()) {
GTEST_SKIP() << "Mosaic GPU is not supported on ROCm.";
}
XLA_FFI_Handler_Bundle bundle = {
/*instantiate=*/nullptr,
/*prepare=*/nullptr,
Expand Down
Loading