diff --git a/xla/backends/gpu/codegen/triton/dot_algorithms_test.cc b/xla/backends/gpu/codegen/triton/dot_algorithms_test.cc index 1419415644f71..e5b3abdbbcf87 100644 --- a/xla/backends/gpu/codegen/triton/dot_algorithms_test.cc +++ b/xla/backends/gpu/codegen/triton/dot_algorithms_test.cc @@ -155,6 +155,10 @@ TEST_F(AlgorithmTest, Algorithm3xBF16) { } TEST_F(AlgorithmTest, Algorithm6xBF16) { + if (GpuComputeComp().IsRocm()) { + if (GpuComputeComp().rocm_compute_capability()->gfx9_mi200()) + GTEST_SKIP() << "ALG_DOT_BF16_BF16_F32_X6 not supported on MI200."; + } constexpr absl::string_view kHloText = R"( HloModule Algorithm6xBF16 @@ -1055,6 +1059,19 @@ class NumericTestsForBlas : public BlasAlgorithmTest, )"; protected: + void SetUp() override { + PC::Algorithm algorithm = GetParam(); + if (GpuComputeComp().IsRocm()) { + if (GpuComputeComp().rocm_compute_capability()->gfx9_mi200() && + (algorithm == PC::ALG_DOT_BF16_BF16_F32_X3 || + algorithm == PC::ALG_DOT_BF16_BF16_F32_X6 || + algorithm == PC::ALG_DOT_BF16_BF16_F32_X9)) { + GTEST_SKIP() << AlgorithmToString(GetParam()) + << " not supported on MI200."; + } + } + } + std::string algorithm_; }; @@ -1551,10 +1568,17 @@ TEST_P(TritonAndBlasSupportForDifferentTensorSizes, case PC::ALG_DOT_BF16_BF16_F32_X6: case PC::ALG_DOT_BF16_BF16_F32_X9: if (GpuComputeComp().IsRocm()) { - // X6 and X9 algorithms on ROCm marked as not supported - // because they often require too much shared memory. - EXPECT_FALSE(result_or_status.value()) - << "algorithms not supported on ROCm"; + if (result_or_status.status().ok()) { + // X6 and X9 algorithms on ROCm marked as not supported + // because they often require too much shared memory. + EXPECT_FALSE(result_or_status.value()) + << "algorithms not supported on ROCm"; + } else { + if (GpuComputeComp().rocm_compute_capability()->gfx9_mi200()) { + EXPECT_EQ(result_or_status.status().code(), + absl::StatusCode::kInternal); + } + } } else { ASSERT_TRUE(result_or_status.status().ok()) << "failed to compile " << algorithm_; diff --git a/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc b/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc index 19e8e197c68db..fefd7875489fb 100644 --- a/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc +++ b/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc @@ -207,6 +207,10 @@ CublasLtMatmulThunk::GetCachedMatmulPlan(const ExecuteParams& params) { auto algorithms, plan->GetAlgorithms(params.stream, num_algorithms, max_workspace)); + if (algorithms.empty()) { + return absl::InternalError( + "Failed to get a MatmulPlan: no valid algorithm found."); + } TF_RETURN_IF_ERROR(plan->SetAlgorithm(algorithms[algorithm_idx_])); return std::move(plan); }; diff --git a/xla/service/gpu/dot_algorithm_support_test.cc b/xla/service/gpu/dot_algorithm_support_test.cc index adf8d27015b99..41ec10ef34292 100644 --- a/xla/service/gpu/dot_algorithm_support_test.cc +++ b/xla/service/gpu/dot_algorithm_support_test.cc @@ -180,6 +180,12 @@ TEST_P(DotAlgorithmSupportTest, AlgorithmIsSupportedFromCudaCapability) { if (params.backend_restriction == BackendRestriction::kTritonOnly) { GTEST_SKIP() << "TODO: Triton unsupported in ROCm"; } + if (rcc->gfx9_mi200() && + (params.algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 || + params.algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9)) { + GTEST_SKIP() << AlgorithmToString(params.algorithm) + << " not supported on MI200."; + } } // CublasLt does not support FP8 fast accumulation. diff --git a/xla/service/gpu/gpu_compiler_test.cc b/xla/service/gpu/gpu_compiler_test.cc index 1f5d58da3d2e4..fd8ad865f5358 100644 --- a/xla/service/gpu/gpu_compiler_test.cc +++ b/xla/service/gpu/gpu_compiler_test.cc @@ -1951,6 +1951,9 @@ XLA_FFI_DEFINE_HANDLER( TEST_F(GpuCompilerTest, ParametersUsedByCollectiveMosaicShouldBeCopiedToCollectiveMemory) { + if (device_description().gpu_compute_capability().IsRocm()) { + GTEST_SKIP() << "Mosaic GPU is not supported on ROCm."; + } XLA_FFI_Handler_Bundle bundle = { /*instantiate=*/nullptr, /*prepare=*/nullptr,