From 2faed86f795cbab6798a9e79b717adcc578a8c54 Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Wed, 12 Nov 2025 21:39:01 +0000 Subject: [PATCH 01/43] Remove nvidia_wheel_versions --- jaxlib/jax.bzl | 2 -- 1 file changed, 2 deletions(-) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 2e2757f2ecc4..7989c6c3f4c7 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -20,7 +20,6 @@ load("@jax_wheel//:wheel.bzl", "WHEEL_VERSION") load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "WHEEL_VERSION_SUFFIX") load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") -load("@nvidia_wheel_versions//:versions.bzl", "NVIDIA_WHEEL_VERSIONS") load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION", "HERMETIC_PYTHON_VERSION_KIND") load("@rocm_external_test_deps//:external_deps.bzl", "EXTERNAL_DEPS") load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library") @@ -461,7 +460,6 @@ def _jax_wheel_impl(ctx): if ctx.attr.platform_version == "": fail("platform_version must be set to a valid cuda version for cuda wheels") args.add("--platform_version", ctx.attr.platform_version) # required for gpu wheels - args.add("--nvidia_wheel_versions_data", NVIDIA_WHEEL_VERSIONS) # required for gpu wheels if ctx.attr.enable_rocm: args.add("--enable-rocm", "True") if ctx.attr.platform_version == "": From 6ba78d224c398fa530a7d8a2bc844c299a0813b4 Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Wed, 12 Nov 2025 21:42:44 +0000 Subject: [PATCH 02/43] Make jaxlib targets visible --- jaxlib/BUILD | 2 +- jaxlib/rocm/BUILD | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index d9a3a619965c..639803e667d4 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -39,7 +39,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//jax:internal"], + default_visibility = ["//visibility:public"], ) package_group( diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 72f2ec6dae80..4f6c21a10c33 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -27,7 +27,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//visibility:public"], ) cc_library( From d3f32d9b88d92667f7f65e9157e8019c76de3660 Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Wed, 12 Nov 2025 22:12:01 +0000 Subject: [PATCH 03/43] hipblas typedef fix --- jaxlib/gpu/solver_interface.cc | 12 ++++++------ jaxlib/gpu/vendor.h | 12 ++++-------- jaxlib/rocm/BUILD | 4 ++++ 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/jaxlib/gpu/solver_interface.cc b/jaxlib/gpu/solver_interface.cc index e10c08d54e9f..8866c7bea2fe 100644 --- a/jaxlib/gpu/solver_interface.cc +++ b/jaxlib/gpu/solver_interface.cc @@ -62,8 +62,8 @@ JAX_GPU_DEFINE_GETRF(gpuDoubleComplex, gpusolverDnZgetrf); JAX_GPU_DEFINE_GETRF_BATCHED(float, gpublasSgetrfBatched); JAX_GPU_DEFINE_GETRF_BATCHED(double, gpublasDgetrfBatched); -JAX_GPU_DEFINE_GETRF_BATCHED(gpublasComplex, gpublasCgetrfBatched); -JAX_GPU_DEFINE_GETRF_BATCHED(gpublasDoubleComplex, gpublasZgetrfBatched); +JAX_GPU_DEFINE_GETRF_BATCHED(gpuComplex, gpublasCgetrfBatched); +JAX_GPU_DEFINE_GETRF_BATCHED(gpuDoubleComplex, gpublasZgetrfBatched); #undef JAX_GPU_DEFINE_GETRF_BATCHED // QR decomposition: geqrf @@ -101,8 +101,8 @@ JAX_GPU_DEFINE_GEQRF(gpuDoubleComplex, gpusolverDnZgeqrf); JAX_GPU_DEFINE_GEQRF_BATCHED(float, gpublasSgeqrfBatched); JAX_GPU_DEFINE_GEQRF_BATCHED(double, gpublasDgeqrfBatched); -JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasComplex, gpublasCgeqrfBatched); -JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasDoubleComplex, gpublasZgeqrfBatched); +JAX_GPU_DEFINE_GEQRF_BATCHED(gpuComplex, gpublasCgeqrfBatched); +JAX_GPU_DEFINE_GEQRF_BATCHED(gpuDoubleComplex, gpublasZgeqrfBatched); #undef JAX_GPU_DEFINE_GEQRF_BATCHED // Householder transformations: orgqr @@ -272,8 +272,8 @@ JAX_GPU_DEFINE_SYEVD(gpuDoubleComplex, gpusolverDnZheevd); JAX_GPU_DEFINE_SYRK(float, gpublasSsyrk); JAX_GPU_DEFINE_SYRK(double, gpublasDsyrk); -JAX_GPU_DEFINE_SYRK(gpublasComplex, gpublasCsyrk); -JAX_GPU_DEFINE_SYRK(gpublasDoubleComplex, gpublasZsyrk); +JAX_GPU_DEFINE_SYRK(gpuComplex, gpublasCsyrk); +JAX_GPU_DEFINE_SYRK(gpuDoubleComplex, gpublasZsyrk); #undef JAX_GPU_DEFINE_SYRK // Singular Value Decomposition: gesvd diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 4e6c4ca9a7d4..fbe5b30daadf 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -446,6 +446,8 @@ inline constexpr uint32_t kNumThreadsPerWarp = 32; #elif defined(JAX_GPU_HIP) +#define HIPBLAS_V2 1 + // IWYU pragma: begin_exports #include "rocm/include/hip/hip_cooperative_groups.h" #include "rocm/include/hip/hip_runtime_api.h" @@ -466,17 +468,11 @@ inline constexpr uint32_t kNumThreadsPerWarp = 32; // MIOpen lib. Remove when MIOpen support is complete. #define MIOPEN_STATUS_SUCCESS 0 -typedef hipFloatComplex gpuComplex; +typedef hipComplex gpuComplex; typedef hipDoubleComplex gpuDoubleComplex; -#if TF_ROCM_VERSION >= 70000 -typedef hipFloatComplex gpublasComplex; +typedef hipComplex gpublasComplex; typedef hipDoubleComplex gpublasDoubleComplex; -#else -typedef hipblasComplex gpublasComplex; -typedef hipblasDoubleComplex gpublasDoubleComplex; -#endif // TF_ROCM_VERSION >= 70000 - typedef struct hipsolverHandle_* gpusolverDnHandle_t; typedef hipblasFillMode_t gpublasFillMode_t; typedef hipsolverFillMode_t gpusolverFillMode_t; diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 4f6c21a10c33..bd810f4a74cf 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -404,6 +404,10 @@ nanobind_extension( "@nanobind", "@xla//xla/ffi/api:ffi", ], + linkopts = [ + "-L/opt/rocm/lib", + "-lamdhip64", + ], ) cc_library( From e874827942b6950635f2a0e8d1e98c25e965958b Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Thu, 13 Nov 2025 21:50:02 +0000 Subject: [PATCH 04/43] No GPU fail --- jaxlib/rocm/BUILD | 1 + jaxlib/rocm/rocm_plugin_extension.cc | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index bd810f4a74cf..f266baeb4683 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -535,6 +535,7 @@ nanobind_extension( srcs = ["rocm_plugin_extension.cc"], module_name = "rocm_plugin_extension", deps = [ + ":hip_gpu_kernel_helpers", ":py_client_gpu", "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:gpu_plugin_extension", diff --git a/jaxlib/rocm/rocm_plugin_extension.cc b/jaxlib/rocm/rocm_plugin_extension.cc index 05be1e81c858..e28e4927b81f 100644 --- a/jaxlib/rocm/rocm_plugin_extension.cc +++ b/jaxlib/rocm/rocm_plugin_extension.cc @@ -23,6 +23,7 @@ limitations under the License. #include "jaxlib/gpu/gpu_plugin_extension.h" #include "jaxlib/gpu/py_client_gpu.h" #include "jaxlib/kernel_nanobind_helpers.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" namespace nb = nanobind; @@ -96,6 +97,13 @@ nb::dict FfiHandlers() { return dict; } +int ROCmDeviceCount() { + int device_count = -1; + JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipInit(0))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipGetDeviceCount(&device_count))); + return device_count; +} + } // namespace NB_MODULE(rocm_plugin_extension, m) { @@ -122,5 +130,6 @@ NB_MODULE(rocm_plugin_extension, m) { return device_ordinal; }, nb::arg("data_value")); + m.def("get_device_count", &ROCmDeviceCount); } } // namespace jax From a3338bb6d6ba6c94a64d6ca469e1fedd55d67590 Mon Sep 17 00:00:00 2001 From: Marco Minutoli Date: Thu, 12 Feb 2026 14:06:16 -0800 Subject: [PATCH 05/43] Wrap HIP inline functions in anonymous namespaces in vendor.h --- jaxlib/gpu/vendor.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index fbe5b30daadf..915a36e5bdd3 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -529,6 +529,7 @@ inline hipblasStatus_t gpublasCreate(gpublasHandle_t* handle) { return hipblasCreate(reinterpret_cast(handle)); } } // namespace jax::hip + #define gpublasCreate ::jax::hip::gpublasCreate #define gpublasSetStream hipblasSetStream #define gpublasSgeqrfBatched hipblasSgeqrfBatched @@ -585,6 +586,7 @@ inline hipsolverStatus_t gpusolverDnCreate(gpusolverDnHandle_t* handle) { return hipsolverCreate(reinterpret_cast(handle)); } } // namespace jax::hip + #define gpusolverDnCreate ::jax::hip::gpusolverDnCreate #define gpusolverDnSetStream hipsolverSetStream #define gpusolverDnCreateSyevjInfo hipsolverCreateSyevjInfo From ce4329934142d8d4eb340a0bed54ed85f5d6a080 Mon Sep 17 00:00:00 2001 From: Dragoslav Sicarov Date: Tue, 10 Jun 2025 04:28:40 +0000 Subject: [PATCH 06/43] SWDEV-512768 - Replace hipGetLastError with hipExtGetLastError --- jaxlib/gpu/vendor.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 915a36e5bdd3..610a15aecb52 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -772,7 +772,7 @@ inline hipsparseStatus_t gpusparseCreate(gpusparseHandle_t* handle) { #define GPU_STREAM_NON_BLOCKING hipStreamNonBlocking #define gpuMalloc hipMalloc -#define gpuGetLastError hipGetLastError +#define gpuGetLastError hipExtGetLastError #define gpuGetErrorString hipGetErrorString #define gpuMemcpyAsync hipMemcpyAsync #define gpuMemcpyDeviceToDevice hipMemcpyDeviceToDevice From c3e235729752a52eb6d9e76fc98bfbe17d7976bb Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Fri, 14 Nov 2025 19:43:09 +0000 Subject: [PATCH 07/43] Add shared utility function get_rocm_version to test_util.py --- jax/_src/test_util.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index df5e1433cd94..10d65eb78cf6 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -25,6 +25,7 @@ import logging import math import os +from pathlib import Path import platform import re import sys @@ -392,6 +393,15 @@ def supported_dtypes(): def is_device_rocm(): return 'rocm' in xla_bridge.get_backend().platform_version +def get_rocm_version(): + rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm") + version_path = Path(rocm_path) / ".info" / "version" + if not version_path.exists(): + raise FileNotFoundError(f"Expected ROCm version file at {version_path}") + version_str = version_path.read_text().strip() + major, minor, *_ = version_str.split(".") + return int(major), int(minor) + def is_device_cuda(): return 'cuda' in xla_bridge.get_backend().platform_version From d6acb7206a69551f87f564a1a13a0981ff545c62 Mon Sep 17 00:00:00 2001 From: Pham Binh Date: Mon, 17 Nov 2025 19:34:29 +0000 Subject: [PATCH 08/43] Fix hipSparse CSR algorithm mappings for ROCm 7 --- jaxlib/gpu/vendor.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 610a15aecb52..be985a8a2306 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -756,10 +756,10 @@ inline hipsparseStatus_t gpusparseCreate(gpusparseHandle_t* handle) { #define GPUSPARSE_INDEX_32I HIPSPARSE_INDEX_32I #define GPUSPARSE_INDEX_64I HIPSPARSE_INDEX_64I #define GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT -#define GPUSPARSE_SPMV_COO_ALG HIPSPARSE_MV_ALG_DEFAULT -#define GPUSPARSE_SPMV_CSR_ALG HIPSPARSE_MV_ALG_DEFAULT -#define GPUSPARSE_SPMM_COO_ALG HIPSPARSE_SPMM_ALG_DEFAULT -#define GPUSPARSE_SPMM_CSR_ALG HIPSPARSE_SPMM_ALG_DEFAULT +#define GPUSPARSE_SPMV_COO_ALG HIPSPARSE_COOMV_ALG +#define GPUSPARSE_SPMV_CSR_ALG HIPSPARSE_CSRMV_ALG1 +#define GPUSPARSE_SPMM_COO_ALG HIPSPARSE_SPMM_COO_ALG1 +#define GPUSPARSE_SPMM_CSR_ALG HIPSPARSE_SPMM_CSR_ALG1 #define GPUSPARSE_INDEX_BASE_ZERO HIPSPARSE_INDEX_BASE_ZERO #define GPUSPARSE_OPERATION_NON_TRANSPOSE HIPSPARSE_OPERATION_NON_TRANSPOSE #define GPUSPARSE_OPERATION_TRANSPOSE HIPSPARSE_OPERATION_TRANSPOSE From c3cf5d37a58ff3b74ef35eabc3104715e5c0e71b Mon Sep 17 00:00:00 2001 From: Pham Binh Date: Thu, 20 Nov 2025 01:30:24 +0200 Subject: [PATCH 09/43] =?UTF-8?q?Fix=20v=5Fpages=20quantization=20and=20ad?= =?UTF-8?q?just=20test=20params=20for=20ROCm=20compatibilit=E2=80=A6=20(#5?= =?UTF-8?q?60)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/pallas/gpu_paged_attention_test.py | 71 +++++++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/tests/pallas/gpu_paged_attention_test.py b/tests/pallas/gpu_paged_attention_test.py index 1b778c787a6d..6a1d8de22a78 100644 --- a/tests/pallas/gpu_paged_attention_test.py +++ b/tests/pallas/gpu_paged_attention_test.py @@ -112,6 +112,64 @@ class PagedAttentionKernelTest(PallasBaseTest): def setUp(self): super().setUp() + def _estimate_shared_memory_bytes(self, block_h, pages_per_compute_block, + page_size, head_dim, dtype): + """Estimate shared memory usage for paged attention kernel.""" + dtype_size = jnp.dtype(dtype).itemsize + # Approximate calculation based on kernel's memory usage + # Q block: block_h * head_dim + # K/V blocks: pages_per_compute_block * page_size * head_dim + # Plus accumulators and intermediate values + block_k = pages_per_compute_block * page_size + estimated = dtype_size * ( + block_h * head_dim + # Q + 2 * block_k * head_dim + # K and V + block_h * block_k + # logits/attention weights + block_h * 8 # accumulators (m, l, etc.) in float32 + ) + return estimated + + def _adjust_params_for_shared_memory(self, block_h, pages_per_compute_block, + page_size, head_dim, dtype): + """Adjust parameters to fit within device shared memory limits. + + Uses XLA's DeviceDescription.shared_memory_per_block_optin() to query + the actual device capability rather than hardcoding values. + """ + try: + device = jax.local_devices()[0] + # Query XLA DeviceDescription for max shared memory per block + # This is exposed from stream_executor::DeviceDescription::shared_memory_per_block_optin() + max_smem = device.shared_memory_per_block_optin + except (AttributeError, IndexError): + # Fallback if XLA doesn't expose shared_memory_per_block_optin (older versions) + # or if no devices are available. Use conservative 48KB (safe for most GPUs). + max_smem = 48 * 1024 + + estimated = self._estimate_shared_memory_bytes( + block_h, pages_per_compute_block, page_size, head_dim, dtype) + + # If within limits, no adjustment needed + if estimated <= max_smem: + return block_h, pages_per_compute_block, page_size + + # Try to reduce parameters to fit + while estimated > max_smem: + if pages_per_compute_block > 2: + pages_per_compute_block = pages_per_compute_block // 2 + elif page_size > 8: + page_size = page_size // 2 + elif block_h > 8: + block_h = block_h // 2 + else: + # Can't reduce further, will need to skip + return None, None, None + + estimated = self._estimate_shared_memory_bytes( + block_h, pages_per_compute_block, page_size, head_dim, dtype) + + return block_h, pages_per_compute_block, page_size + @jtu.sample_product( dtype=(jnp.float16,), page_size=(8, 16, 32), @@ -201,6 +259,17 @@ def test_quantized_paged_attention( if (quant_dtype == jnp.float8_e4m3fn and not jtu.is_cuda_compute_capability_at_least("8.9")): self.skipTest("Skipping since float8_e4m3fn is not supported on < sm89") + + # Check and adjust parameters if needed to fit device limits for ROCm + if jtu.is_device_rocm(): + adjusted = self._adjust_params_for_shared_memory( + block_h, pages_per_compute_block, page_size, head_dim, dtype) + + if adjusted == (None, None, None): + self.skipTest("Cannot adjust parameters to fit ROCm device shared memory limits") + + block_h, pages_per_compute_block, page_size = adjusted + max_kv_len = 2048 seq_lens = np.asarray([3, 256, 513, 1023, 2048], dtype=jnp.int32) q, k_pages, v_pages, block_tables = _generate_qkv( @@ -218,7 +287,7 @@ def test_quantized_paged_attention( k_, k_scales = (_quantize(k_pages, quant_dtype) if quantize_k else (k_pages, None)) - v_, v_scales = (_quantize(k_pages, quant_dtype) + v_, v_scales = (_quantize(v_pages, quant_dtype) if quantize_v else (v_pages, None)) o = paged_attention.paged_attention( From 04b3d823b700cafa3f37ba121b36104410e5e3fa Mon Sep 17 00:00:00 2001 From: Aleksei <208770786+Arech8@users.noreply.github.com> Date: Wed, 26 Nov 2025 17:19:56 +0200 Subject: [PATCH 10/43] Address LLVM assertion failure due to a multithreaded use. Update .gitignore (#563) When jaxlib was built in debug more, an assertion in LLVM code that lazy-loads VHLO dialect could fire, since the code path could execute in a multi-threaded environment, and LLVM dialect repositories aren't thread safe to modify. This patch applies the same changes that upstream makes to fix this: https://github.com/jax-ml/jax/commit/48c876227a67840d7bcd44d80d16edbc0e910335 (this includes disabling a call to `jax_mlir_ext.enter_multi_threaded_execution(context)` in `mlir.py`. Presumably, the whole functionality related to `enter_multi_threaded_execution()` multithreaded checks isn't ready yet, and it was prematurely rolled into the production code. Manual testing --- .gitignore | 4 ++++ jax/_src/interpreters/mlir.py | 5 ++--- jaxlib/mlir/_mlir_libs/jax_mlir_ext.cc | 1 - 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index d30c019b31e8..e55a9a5de101 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,7 @@ jax.iml /include/ /lib/ /share/ + +/compile_commands.json +/strace.txt +/external diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 42d62340bd92..e25d4bbdf792 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -564,9 +564,8 @@ def make_ir_context() -> ir.Context: # multi threaded execution aborts the process if we try to register a new # dialect after this point. The dialect registry in a context is not thread # safe, and a fatal error is much better than a data race. - # jax_mlir_ext.enter_multi_threaded_execution(context) - # TODO(phawkins): clean up users who add their own dialects to JAX's contexts - # and enable this. + # if jaxlib_version >= (0, 8): + # jax_mlir_ext.enter_multi_threaded_execution(context) return context diff --git a/jaxlib/mlir/_mlir_libs/jax_mlir_ext.cc b/jaxlib/mlir/_mlir_libs/jax_mlir_ext.cc index f1bf60bb0517..65ee1673ca7c 100644 --- a/jaxlib/mlir/_mlir_libs/jax_mlir_ext.cc +++ b/jaxlib/mlir/_mlir_libs/jax_mlir_ext.cc @@ -204,7 +204,6 @@ NB_MODULE(_jax_mlir_ext, m) { unwrap(registry)->insert(); unwrap(registry)->insert(); unwrap(registry)->insert(); - // For Mosaic GPU REGISTER_DIALECT(cf); REGISTER_DIALECT(gpu); From d9103b3aa05aa00a10d399af0e9c79df66cd13c8 Mon Sep 17 00:00:00 2001 From: Aleksei <208770786+Arech8@users.noreply.github.com> Date: Wed, 26 Nov 2025 18:35:19 +0200 Subject: [PATCH 11/43] Add skip of test_is_finite() on Cuda (#565) (forgot this skip in the previous PR) --- tests/pallas/ops_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index e8628dda490c..28b669f22a41 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1003,6 +1003,10 @@ def test_is_finite(self, dtype): # The original test worked only on fp32@TPU, have no way to test CUDA self.skipTest("Not tested on CUDA, todo for the respective team") + if jtu.test_device_matches(["cuda"]): + self.skipTest("Not tested on CUDA") # set this b/c this how the test was + # originally configured. Have no way to test cuda. + size = len(self.IS_FINITE_TEST_VALUES) @functools.partial( From 3a63ef39ebf778e2b14d47c2020f596a3b06c8a4 Mon Sep 17 00:00:00 2001 From: AratiGanesh Date: Mon, 15 Dec 2025 08:00:42 -0800 Subject: [PATCH 12/43] Add rocm test requirements file (#570) --- build/rocm-test-requirements.txt | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 build/rocm-test-requirements.txt diff --git a/build/rocm-test-requirements.txt b/build/rocm-test-requirements.txt new file mode 100644 index 000000000000..399237175957 --- /dev/null +++ b/build/rocm-test-requirements.txt @@ -0,0 +1,24 @@ +absl-py +build +cloudpickle +colorama>=0.4.4 +filelock +flatbuffers +hypothesis +mpmath>=1.3 +pillow>=10.4.0 +# TODO(kanglan): Remove once psutil from portpicker supports python 3.13t +portpicker; python_version<"3.13" +pytest-xdist +pytest-json-report +pytest-html +pytest-csv +pytest-rerunfailures +pytest-html-merger +pytest-reportlog +wheel +rich +setuptools +matplotlib +opt-einsum +auditwheel From 11a084baf4c457c08539f26cc3f1acad49a18322 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Mon, 15 Dec 2025 11:03:51 -0600 Subject: [PATCH 13/43] Let the unit tests use build.py for setting up Bazel commands for unit tests (#582) --- build/build.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/build/build.py b/build/build.py index 79de19508687..478ae0586c63 100755 --- a/build/build.py +++ b/build/build.py @@ -638,6 +638,10 @@ async def main(): ) if "rocm" in args.wheels: + if not args.configure_only: + print("ERROR: This repo is not used for building the ROCm JAX plugins. Please use the new plugin repo: https://github.com/ROCm/rocm-jax") + exit(1) + wheel_build_command_base.append("--config=rocm_base") wheel_build_command_base.append("--config=rocm") if clang_local: From 5a2d899d965e75edbf2182f68a7961bea70017ea Mon Sep 17 00:00:00 2001 From: Gulsum Gudukbay Akbulut Date: Tue, 13 Jan 2026 10:29:56 -0600 Subject: [PATCH 14/43] adding abort logic to rocm/jax (#590) --- conftest.py | 177 +++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 176 insertions(+), 1 deletion(-) diff --git a/conftest.py b/conftest.py index fa0e6de94346..72b4b598891c 100644 --- a/conftest.py +++ b/conftest.py @@ -15,7 +15,10 @@ import os import pytest - +import json +import threading +import shutil +from datetime import datetime @pytest.fixture(autouse=True) def add_imports(doctest_namespace): @@ -72,3 +75,175 @@ def pytest_collection() -> None: os.environ.setdefault( "CUDA_VISIBLE_DEVICES", str(xdist_worker_number % num_cuda_devices) ) + +class ThreadSafeTestLogger: + """Thread-safe logging for parallel test execution and abort detection""" + def __init__(self): + self.locks = {} + self.global_lock = threading.Lock() + self.base_dir = os.path.abspath("./logs") + + # Create logs directory (archiving is handled by test runner scripts) + try: + os.makedirs(self.base_dir, exist_ok=True) + print(f"[TestLogger] Initialized log directory: {self.base_dir}") + except Exception as e: + print(f"[TestLogger] ERROR: Failed to create log directory {self.base_dir}: {e}") + # Fallback to temp directory if logs dir creation fails + import tempfile + self.base_dir = os.path.join(tempfile.gettempdir(), "jax_test_logs") + os.makedirs(self.base_dir, exist_ok=True) + print(f"[TestLogger] Using fallback directory: {self.base_dir}") + + def get_file_lock(self, test_file): + """Get or create a lock for a specific test file""" + with self.global_lock: + if test_file not in self.locks: + self.locks[test_file] = threading.Lock() + return self.locks[test_file] + + def get_test_file_name(self, session): + """Extract the test file name from the session""" + # Try to get from session config args + if hasattr(session, "config") and hasattr(session.config, "args"): + for arg in session.config.args: + # Handle full nodeid like "jax/tests/foo_test.py::TestClass::test_method" + if "tests/" in arg: + # Split on :: to get just the file path + file_path = arg.split("::")[0] + if file_path.endswith(".py"): + return os.path.basename(file_path).replace(".py", "") + + # Try to get from invocation params + if hasattr(session, "config") and hasattr(session.config, "invocation_params"): + invocation_dir = getattr(session.config.invocation_params, "dir", None) + if invocation_dir: + dir_name = os.path.basename(str(invocation_dir)) + if dir_name: + print(f"[TestLogger] Using invocation directory as test name: {dir_name}") + return dir_name + + # Last resort: try to get from session items + if hasattr(session, "items") and session.items: + first_item = session.items[0] + if hasattr(first_item, "fspath"): + fspath = str(first_item.fspath) + if ".py" in fspath: + return os.path.basename(fspath).replace(".py", "") + + print(f"[TestLogger] WARNING: Could not determine test file name, using 'unknown_test'") + print(f"[TestLogger] Session config args: {getattr(session.config, 'args', 'N/A')}") + return "unknown_test" + + def log_running_test(self, test_file, test_name, nodeid, start_time): + """Log the currently running test for abort detection""" + lock = self.get_file_lock(test_file) + with lock: + log_data = { + "test_file": test_file, + "test_name": test_name, + "nodeid": nodeid, + "start_time": start_time, + "status": "running", + "pid": os.getpid(), + "gpu_id": os.environ.get("HIP_VISIBLE_DEVICES", "unknown"), + } + + log_file = f"{self.base_dir}/{test_file}_last_running.json" + try: + # Ensure directory still exists (might have been deleted) + os.makedirs(self.base_dir, exist_ok=True) + with open(log_file, "w") as f: + json.dump(log_data, f, indent=2) + except Exception as e: + print(f"[TestLogger] ERROR: Failed to write running test log to {log_file}: {e}") + print(f"[TestLogger] Current working directory: {os.getcwd()}") + print(f"[TestLogger] Base directory: {self.base_dir}") + print(f"[TestLogger] Base directory exists: {os.path.exists(self.base_dir)}") + raise + + def clear_running_test(self, test_file): + """Clear the running test log when test completes successfully""" + lock = self.get_file_lock(test_file) + with lock: + log_file = f"{self.base_dir}/{test_file}_last_running.json" + if os.path.exists(log_file): + os.remove(log_file) + + +# Global logger instance +test_logger = ThreadSafeTestLogger() + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_protocol(item, nextitem): + """Hook that wraps around each test to track running tests for crash detection. + + This creates a "last_running" file before each test starts and deletes it + when the test completes successfully. If the test crashes, the file remains + and can be detected by the test runner. + """ + test_file = test_logger.get_test_file_name(item.session) + test_name = item.name + nodeid = item.nodeid + start_time = datetime.now().isoformat() + + # Log that this test is starting + try: + test_logger.log_running_test(test_file, test_name, nodeid, start_time) + except Exception as e: + print(f"[TestLogger] WARNING: Failed to log running test: {e}") + # Continue anyway - not critical for test execution + + test_completed = False + try: + outcome = yield + # Test completed (successfully or with normal failure) + test_completed = True + + # Clear the crash detection file + try: + test_logger.clear_running_test(test_file) + except Exception as e: + print(f"[TestLogger] WARNING: Failed to clear running test log: {e}") + + except Exception as e: + # Test raised exception (might be crash, might be normal exception) + print(f"[TestLogger] Test {test_name} exception: {e}") + if not test_completed: + # Don't clear the file - this might be a crash + print(f"[TestLogger] Leaving crash file for detection") + raise + + +@pytest.hookimpl(tryfirst=True) +def pytest_sessionstart(session): + """Called after the Session object has been created""" + gpu = os.environ.get('HIP_VISIBLE_DEVICES', '?') + print(f"Test session starting on GPU {gpu}") + + +@pytest.hookimpl(trylast=True) +def pytest_sessionfinish(session, exitstatus): + """Called after test run finished. + + If a crash file still exists, it means a test crashed and the runner + will detect it. We just report it here for visibility. + """ + test_file = test_logger.get_test_file_name(session) + log_file = f"{test_logger.base_dir}/{test_file}_last_running.json" + + if os.path.exists(log_file): + try: + with open(log_file, "r") as f: + abort_data = json.load(f) + print( + f"\n[CRASH DETECTED] {abort_data.get('nodeid', abort_data.get('test_name', 'unknown'))} " + f"(GPU: {abort_data.get('gpu_id', '?')}, PID: {abort_data.get('pid', '?')})" + ) + print(f"[CRASH DETECTED] Crash file will be processed by test runner") + except Exception as e: + print(f"[TestLogger] WARNING: Crash file exists but unreadable: {e}") + else: + # Normal completion - no crash + pass From becc59eebc10b9645a647ccb43c12f5e2908a63f Mon Sep 17 00:00:00 2001 From: Pham Binh Date: Wed, 14 Jan 2026 19:57:26 +0200 Subject: [PATCH 15/43] Skip is_finite tests on ROCm (not in Triton lowering for jax 0.8.0) (#597) --- tests/pallas/ops_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 28b669f22a41..7e46a2dc6753 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1007,6 +1007,9 @@ def test_is_finite(self, dtype): self.skipTest("Not tested on CUDA") # set this b/c this how the test was # originally configured. Have no way to test cuda. + if jtu.is_device_rocm(): + self.skipTest("is_finite not in Triton lowering for jax 0.8.0") + size = len(self.IS_FINITE_TEST_VALUES) @functools.partial( @@ -1057,6 +1060,9 @@ def test_is_finite_scalar(self, dtype): # The original test worked only on fp32@TPU, have no way to test CUDA self.skipTest("Not tested on CUDA, todo for the respective team") + if jtu.is_device_rocm(): + self.skipTest("is_finite not in Triton lowering for jax 0.8.0") + size = len(self.IS_FINITE_TEST_VALUES) @functools.partial( From cbfb842220498aa6ef618d909decd910329a5e57 Mon Sep 17 00:00:00 2001 From: Pham Binh Date: Wed, 14 Jan 2026 19:57:53 +0200 Subject: [PATCH 16/43] Fix shared memory limit check for ROCm in test_dot (#596) --- tests/pallas/ops_test.py | 50 +++++++++++++++++++++++++++++++++++----- 1 file changed, 44 insertions(+), 6 deletions(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 7e46a2dc6753..707d89f38b9f 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -74,6 +74,34 @@ def is_power_of_two(n: int) -> bool: return (n > 0) and (n & (n - 1) == 0) +def get_rocm_shared_memory_limit() -> int: + """Get the shared memory (LDS) limit in bytes for ROCm devices. + + Queries rocminfo to get the GROUP segment size dynamically. + Returns 64KB as default if rocminfo fails (MI100/MI200/MI300 all have 64KB LDS). + """ + try: + result = subprocess.run( + ['rocminfo'], capture_output=True, text=True, timeout=10 + ) + if result.returncode != 0: + return 64 * 1024 # Default if rocminfo fails + lines = result.stdout.split('\n') + for i, line in enumerate(lines): + if 'Segment:' in line and 'GROUP' in line: + if i + 1 < len(lines): + size_line = lines[i + 1] + # Match "Size: () KB" with case-insensitive KB check + match = re.search(r'Size:\s+(\d+)\s*\([^)]+\)\s*KB', size_line, re.IGNORECASE) + if match: + size_kb = int(match.group(1)) + return size_kb * 1024 # Convert KB to bytes + except Exception: + pass + # Default for AMD GPUs (MI100/MI200/MI300 all have 64KB LDS) + return 64 * 1024 + + def smem_on_tpu(): if jtu.test_device_matches(["tpu"]): return pltpu.SMEM @@ -2095,12 +2123,22 @@ def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y): if jtu.test_device_matches(["gpu"]): if dtype == jnp.bfloat16: self.skipTest("bfloat16 type are not supported on GPU") - if ( - math.prod(lhs_shape) + math.prod(rhs_shape) + math.prod(out_shape) - > (256 * 256) * 2 - ): - self.skipTest("Shared memory size limit exceeded") - if (jax.local_devices()[0].shared_memory_per_block_optin == 99 * 1024 and + # Check shared memory limit: Triton loads lhs + rhs into shared memory + if jtu.is_device_rocm(): + # ROCm: use correct formula with dynamic limit from rocminfo + dtype_size = jnp.dtype(dtype).itemsize + shared_mem_bytes = (math.prod(lhs_shape) + math.prod(rhs_shape)) * dtype_size + shared_mem_limit = get_rocm_shared_memory_limit() + if shared_mem_bytes > shared_mem_limit: + self.skipTest("Shared memory size limit exceeded") + else: + # NVIDIA: keep original check + if ( + math.prod(lhs_shape) + math.prod(rhs_shape) + math.prod(out_shape) + > (256 * 256) * 2 + ): + self.skipTest("Shared memory size limit exceeded") + if (jax.local_devices()[0].device_kind == "NVIDIA L4" and dtype == jnp.float32 and lhs_and_rhs_shape in [ ((128, 16), (128, 256)), From c979baff9d624a2a72d2379e7a99b86564c40498 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Wed, 14 Jan 2026 12:14:02 -0600 Subject: [PATCH 17/43] Fix Numpy signatures test (#598) Co-authored-by: Daniel Suo Co-authored-by: Jake VanderPlas --- tests/lax_numpy_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 9ce0623c8b91..3a11614c337a 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6390,6 +6390,10 @@ def testWrappedSignaturesMatch(self): 'stack': ['casting'], 'tri': ['like'], 'unravel_index': ['order'], +<<<<<<< HEAD +======= + 'var': ['mean'], +>>>>>>> a3f8af53f (Fix Numpy signatures test (#598)) 'vstack': ['casting'], 'zeros': ['order', 'like'], 'zeros_like': ['subok', 'order'] From e5fddf8a1b7341c7d56d2cb67caef4923c7dffef Mon Sep 17 00:00:00 2001 From: Ruturaj4 Date: Sun, 18 Jan 2026 10:18:54 -0600 Subject: [PATCH 18/43] fix merge arts --- tests/lax_numpy_test.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 3a11614c337a..e1bc3d29f047 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6390,10 +6390,7 @@ def testWrappedSignaturesMatch(self): 'stack': ['casting'], 'tri': ['like'], 'unravel_index': ['order'], -<<<<<<< HEAD -======= 'var': ['mean'], ->>>>>>> a3f8af53f (Fix Numpy signatures test (#598)) 'vstack': ['casting'], 'zeros': ['order', 'like'], 'zeros_like': ['subok', 'order'] From ce0783a5194adf355f2ab8d348e4d5f572e53f81 Mon Sep 17 00:00:00 2001 From: Gulsum Gudukbay Akbulut Date: Thu, 22 Jan 2026 15:44:41 -0600 Subject: [PATCH 19/43] Enable RngShardingTests (#644) --- tests/array_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/array_test.py b/tests/array_test.py index 61c30b9ed065..970b93050de4 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1591,7 +1591,7 @@ class RngShardingTest(jtu.JaxTestCase): # tests that the PRNGs are automatically sharded as expected @parameterized.named_parameters(("3", 3), ("4", 4), ("5", 5)) - @jtu.skip_on_devices("gpu") + @jtu.skip_on_devices("cuda") def test_random_bits_is_pure_map_1d(self, num_devices): @jax.jit def f(x): @@ -1625,7 +1625,7 @@ def f(x): "mesh_shape": mesh_shape, "pspec": pspec} for mesh_shape in [(3, 2), (4, 2), (2, 3)] for pspec in [P('x', None), P(None, 'y'), P('x', 'y')]) - @jtu.skip_on_devices("gpu") + @jtu.skip_on_devices("cuda") def test_random_bits_is_pure_map_2d(self, mesh_shape, pspec): @jax.jit def f(x): From 472b2279929215d6a3f109ce7cd3285cd8d56942 Mon Sep 17 00:00:00 2001 From: Marco Minutoli Date: Thu, 12 Feb 2026 15:18:40 -0800 Subject: [PATCH 20/43] Enable test_variadic_reduce_window on ROCm (#647) --- tests/lax_vmap_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 57fdfc4dda88..de260ec4ed93 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -759,6 +759,8 @@ def testSort(self, shape, dimension, arity, bdims, is_stable): # TODO Collapse # TODO Scatter + # b/183233858: variadic reduce-window not implemented on XLA:CUDA + @jtu.skip_on_devices("cuda") def test_variadic_reduce_window(self): # https://github.com/jax-ml/jax/discussions/9818 and # https://github.com/jax-ml/jax/issues/9837 From 69381b1ca3d72dc275f9591b8a592326069f0e43 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 13:38:16 -0600 Subject: [PATCH 21/43] Skip sparse tests on ROCm due to hipSPARSE issue (#652) --- tests/sparse_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 58adf3a42cf1..b5bc7b511ec9 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,6 +137,7 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): # hipSPARSE segfault observed as of ROCm 7.2. @@ -219,6 +220,7 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): # hipSPARSE segfault observed as of ROCm 7.2. @@ -592,6 +594,7 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): # hipSPARSE segfault observed as of ROCm 7.2. @@ -1047,6 +1050,7 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): # hipSPARSE segfault observed as of ROCm 7.2. From 2988d2ad067e4e79255b47af18ed52f87fa679b5 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 15:35:23 -0600 Subject: [PATCH 22/43] Update sparse test skip messages in v0.8.2 (#653) --- tests/sparse_test.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index b5bc7b511ec9..30240ff69c50 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,11 +137,8 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): - # hipSPARSE segfault observed as of ROCm 7.2. - # TODO(ROCm): Re-enable once hipSPARSE issue is fixed. self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") csr_matmul = sparse_csr._csr_matvec if len(bshape) == 1 else sparse_csr._csr_matmat tol = {np.float32: 2E-5, np.float64: 1E-12, np.complex64: 1E-5, @@ -220,11 +217,8 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): - # hipSPARSE segfault observed as of ROCm 7.2. - # TODO(ROCm): Re-enable once hipSPARSE issue is fixed. self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") op = lambda M: M.T if transpose else M @@ -594,11 +588,8 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): - # hipSPARSE segfault observed as of ROCm 7.2. - # TODO(ROCm): Re-enable once hipSPARSE issue is fixed. self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") tol = {np.float32: 2E-5, np.float64: 2E-14} @@ -1050,11 +1041,8 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): - # hipSPARSE segfault observed as of ROCm 7.2. - # TODO(ROCm): Re-enable once hipSPARSE issue is fixed. self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") rng = sptu.rand_sparse(self.rng(), post=jnp.array) rng_b = jtu.rand_default(self.rng()) From 3f3518fce5d6e6e4b7c9f860918adde14178c585 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 13:38:16 -0600 Subject: [PATCH 23/43] Skip sparse tests on ROCm due to hipSPARSE issue (#652) --- tests/sparse_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 30240ff69c50..fe6a035f75c8 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,6 +137,7 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -217,6 +218,7 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") @@ -588,6 +590,7 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1041,6 +1044,7 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") From 2045fd0d7bba473993e122dfdbc02b0a9dedbf31 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 15:35:23 -0600 Subject: [PATCH 24/43] Update sparse test skip messages in v0.8.2 (#653) --- tests/sparse_test.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index fe6a035f75c8..30240ff69c50 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,7 +137,6 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -218,7 +217,6 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") @@ -590,7 +588,6 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1044,7 +1041,6 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") From b798b3b38a875d59a09bd5c00dce01c3596293d9 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 13:38:16 -0600 Subject: [PATCH 25/43] Skip sparse tests on ROCm due to hipSPARSE issue (#652) --- tests/sparse_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 30240ff69c50..fe6a035f75c8 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,6 +137,7 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -217,6 +218,7 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") @@ -588,6 +590,7 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1041,6 +1044,7 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") From c7ccce22e2f6fc615f8d5e0b83f5da1f4a5786c5 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 15:35:23 -0600 Subject: [PATCH 26/43] Update sparse test skip messages in v0.8.2 (#653) --- tests/sparse_test.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index fe6a035f75c8..30240ff69c50 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,7 +137,6 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -218,7 +217,6 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") @@ -590,7 +588,6 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1044,7 +1041,6 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") From 2e76caea187ec0bd0b68ca65163ef9e22e7ab191 Mon Sep 17 00:00:00 2001 From: AratiGanesh Date: Wed, 28 Jan 2026 14:00:42 -0800 Subject: [PATCH 27/43] Enable testMultivariateNormalSingularCovariance on ROCm (#666) --- tests/random_lax_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index 1842da7c1b25..44314e9034fe 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -963,7 +963,7 @@ def testMultivariateNormalCovariance(self): check_dtypes=False) @jtu.sample_product(method=['cholesky', 'eigh', 'svd']) - @jtu.skip_on_devices('cuda', 'tpu') # Some NaNs on accelerators. + @jtu.skip_on_devices('cuda', 'tpu') # Some NaNs on accelerators. ROCm supported def testMultivariateNormalSingularCovariance(self, method): # Singular covariance matrix https://github.com/jax-ml/jax/discussions/13293 mu = jnp.zeros((2,)) From 25bc14dd030d1594f045484927f5aa1b3416dc88 Mon Sep 17 00:00:00 2001 From: AratiGanesh Date: Wed, 28 Jan 2026 14:02:39 -0800 Subject: [PATCH 28/43] Skip test_tridiagonal_solve on ROCm due to hipSPARSE numerical errors (#668) --- tests/linalg_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 63c8a28bb76b..754cf5b09e64 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -2371,6 +2371,7 @@ def testSelect(self, dtype): @jtu.sample_product(shape=[(3,), (3, 4), (3, 4, 5)], dtype=float_types + complex_types) + @jtu.skip_on_devices("rocm") # Numerical errors on ROCm def test_tridiagonal_solve(self, shape, dtype): if dtype not in float_types and jtu.test_device_matches(["gpu"]): self.skipTest("Data type not supported on GPU") From 55c86bc935c7060b36683451d3befb34099b1aa8 Mon Sep 17 00:00:00 2001 From: Gulsum Gudukbay Akbulut Date: Wed, 28 Jan 2026 16:38:17 -0600 Subject: [PATCH 29/43] Update Skip Reason Outputs (#663) --- jax/_src/test_util.py | 20 ++++++++++++++++---- tests/pallas/gpu_pallas_distributed_test.py | 14 ++++++++------ 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 10d65eb78cf6..b9ed2fa67f05 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -623,7 +623,14 @@ def skip_on_devices(*disabled_devices, skip_reason=None): skip_reason: Optional custom skip message when test is skipped. """ if skip_reason is None: - skip_reason = "Skipped on devices with tags: " + ", ".join(disabled_devices) + skip_messages = { + ("gpu",): "Skipped on all GPUs.", + ("cpu",): "Skipped on CPU.", + ("tpu",): "Skipped on TPU.", + ("cuda",): "Skipped on CUDA GPUs.", + ("rocm",): "Skipped on ROCm GPUs.", + } + skip_reason = skip_messages.get(disabled_devices) return _device_filter(lambda: not test_device_matches(disabled_devices), skip_reason) def run_on_devices(*enabled_devices, skip_reason=None): @@ -634,9 +641,14 @@ def run_on_devices(*enabled_devices, skip_reason=None): skip_reason: Optional custom skip message when test is skipped. """ if skip_reason is None: - skip_reason = ( - "Skipped unless running on devices with tags: " + ", ".join(enabled_devices) - ) + device_specific_skip_reasons = { + ("cpu",): "Skipped: CPU-only test.", + ("tpu",): "Skipped: TPU-only test.", + ("gpu",): "Skipped: GPU-only test.", + ("rocm",): "Skipped: ROCm-only test.", + ("cuda",): "Skipped: CUDA-only test.", + } + skip_reason = device_specific_skip_reasons.get(enabled_devices) return _device_filter(lambda: test_device_matches(enabled_devices), skip_reason) def device_supports_buffer_donation(): diff --git a/tests/pallas/gpu_pallas_distributed_test.py b/tests/pallas/gpu_pallas_distributed_test.py index 610965cb8429..dce9fc329187 100644 --- a/tests/pallas/gpu_pallas_distributed_test.py +++ b/tests/pallas/gpu_pallas_distributed_test.py @@ -51,20 +51,22 @@ def setUp(self): if jtu.test_device_matches(["rocm"]): self.skipTest("Mosaic not supported on ROCm currently.") + # Check mosaic support first (before GPU capability check) + if not mgpu.supports_cross_device_collectives(): + if jtu.test_device_matches(["rocm"]): + self.skipTest("Mosaic not supported on ROCm currently.") + else: + self.skipTest("NVSHMEM library unavailable.") if (not jtu.test_device_matches(["cuda"]) or not jtu.is_cuda_compute_capability_at_least("9.0")): self.skipTest("Only works on GPU with capability >= sm90") - if not mgpu.supports_cross_device_collectives(): - self.skipTest( - "Skip test since cross-device collectives are not supported" - " (either NVSHMEM is not available in multi-process mode, or mixed" - " mode is used).") + if jax.process_count() == 1: + self.skipTest("Test requires multiple processes.") if os.environ.get("XLA_PYTHON_CLIENT_ALLOCATOR", "") == "platform": self.skipTest("NVSHMEM doesn't work with the platform allocator.") super().setUp() - class PallasCallRemoteDMATest(TestCase): def test_remote_dma_basic(self): From 87d71117a5cb7969f3f5ae7852d77e498209ddec Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 13:38:16 -0600 Subject: [PATCH 30/43] Skip sparse tests on ROCm due to hipSPARSE issue (#652) --- tests/sparse_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 30240ff69c50..fe6a035f75c8 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,6 +137,7 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -217,6 +218,7 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") @@ -588,6 +590,7 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1041,6 +1044,7 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") From 3f26dfeb308f9e8f068cd11f5aaa641a9f6d67c0 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 15:35:23 -0600 Subject: [PATCH 31/43] Update sparse test skip messages in v0.8.2 (#653) --- tests/sparse_test.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index fe6a035f75c8..30240ff69c50 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,7 +137,6 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -218,7 +217,6 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") @@ -590,7 +588,6 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1044,7 +1041,6 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") From 9df214c0879199cb69855b32b1a6c16a98e57610 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Thu, 29 Jan 2026 13:12:07 -0600 Subject: [PATCH 32/43] Skip testCudaArrayInterfaceOnNonCudaFails on ROCm platform (#677) --- tests/array_interoperability_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index d033a490f547..b14e9171f952 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -312,7 +312,7 @@ def testCudaArrayInterfaceOnNonCudaFails(self): self.assertFalse(hasattr(x, "__cuda_array_interface__")) with self.assertRaisesRegex( AttributeError, - "__cuda_array_interface__ is only defined for .*GPU buffers.", + "__cuda_array_interface__ is only defined for GPU buffers.", ): _ = x.__cuda_array_interface__ From 73824693e474b03b413cca7b85a5a16cf33a6747 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 13:38:16 -0600 Subject: [PATCH 33/43] Skip sparse tests on ROCm due to hipSPARSE issue (#652) --- tests/sparse_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 30240ff69c50..fe6a035f75c8 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,6 +137,7 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -217,6 +218,7 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") @@ -588,6 +590,7 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1041,6 +1044,7 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") From f3efc4d317c0189e00f0de6e9d3dbc98c1f8c8d1 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 15:35:23 -0600 Subject: [PATCH 34/43] Update sparse test skip messages in v0.8.2 (#653) --- tests/sparse_test.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index fe6a035f75c8..30240ff69c50 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,7 +137,6 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -218,7 +217,6 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") @@ -590,7 +588,6 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1044,7 +1041,6 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") From 991bc2fb6efdf958b836e60f034ec0723760b5e3 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 13:38:16 -0600 Subject: [PATCH 35/43] Skip sparse tests on ROCm due to hipSPARSE issue (#652) --- tests/sparse_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 30240ff69c50..fe6a035f75c8 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,6 +137,7 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -217,6 +218,7 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") @@ -588,6 +590,7 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1041,6 +1044,7 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") From abfda2d5657dcef930c70af39f1c2eedd1bf2442 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 15:35:23 -0600 Subject: [PATCH 36/43] Update sparse test skip messages in v0.8.2 (#653) --- tests/sparse_test.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index fe6a035f75c8..30240ff69c50 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,7 +137,6 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -218,7 +217,6 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") @@ -590,7 +588,6 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1044,7 +1041,6 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") From 30f61f4c0a0c7d7b1bbcd9b38884d430ec421ccf Mon Sep 17 00:00:00 2001 From: AratiGanesh Date: Thu, 5 Feb 2026 10:50:39 -0800 Subject: [PATCH 37/43] Add ROCm encoding for test_struct_encoding_determinism (#683) --- tests/experimental_rnn_test.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/experimental_rnn_test.py b/tests/experimental_rnn_test.py index 4bb611dcd842..61e79c00a09e 100644 --- a/tests/experimental_rnn_test.py +++ b/tests/experimental_rnn_test.py @@ -179,7 +179,7 @@ def f(weights, x, h_0, c_0): y_padded = y_ref[i, seq_lengths[i]:] np.testing.assert_allclose(y_padded, jnp.zeros_like(y_padded)) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_struct_encoding_determinism(self): def f(k1, k2, k3, k4): batch_size = 1 @@ -213,8 +213,15 @@ def f(k1, k2, k3, k4): k = jax.random.split(jax.random.PRNGKey(1), 4) stablehlo = jax.jit(f).lower(*k).as_text("stablehlo") - self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"', - stablehlo) + # Platform-specific binary encodings for RnnDescriptor + cuda_encoding = '"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"' + rocm_encoding = '"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\008\\00\\00\\00\\00\\00\\00\\00\\1C\\00\\00\\00\\00\\00\\00\\00"' + + # Check that one of the expected encodings is present + if jtu.test_device_matches(["cuda"]): + self.assertIn(cuda_encoding, stablehlo) + elif jtu.test_device_matches(["rocm"]): + self.assertIn(rocm_encoding, stablehlo) # Note: Other LSTM tests that use `bidirectional=True` on ROCm are skipped # because of current numerical issues (as of ROCm 7.1.1). However, this From c63267880c8d7d53f3e15247489bc7bc54edcbe3 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 6 Feb 2026 12:48:48 -0600 Subject: [PATCH 38/43] Remove 'mean' from unsupported params for jnp.var (#689) --- tests/lax_numpy_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index e1bc3d29f047..9ce0623c8b91 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6390,7 +6390,6 @@ def testWrappedSignaturesMatch(self): 'stack': ['casting'], 'tri': ['like'], 'unravel_index': ['order'], - 'var': ['mean'], 'vstack': ['casting'], 'zeros': ['order', 'like'], 'zeros_like': ['subok', 'order'] From cd072cd8d4ff447d446b98ec57f08544533723c7 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 6 Feb 2026 16:49:53 -0600 Subject: [PATCH 39/43] Implement approx_tanh for ROCm using OCML tanh function (#691) --- jax/_src/pallas/triton/primitives.py | 92 ++++++++++++++++++++++++++++ tests/pallas/ops_test.py | 41 +++++++++++++ 2 files changed, 133 insertions(+) diff --git a/jax/_src/pallas/triton/primitives.py b/jax/_src/pallas/triton/primitives.py index 8e763c3d8e6a..25423ebba471 100644 --- a/jax/_src/pallas/triton/primitives.py +++ b/jax/_src/pallas/triton/primitives.py @@ -48,6 +48,11 @@ def approx_tanh(x: jax.Array) -> jax.Array: elif x.dtype == jnp.float32: asm = "tanh.approx.f32 $0, $1;" constraint = "f" + elif x.dtype == jnp.float64: + # f64 tanh.approx is only supported on ROCm (uses __ocml_tanh_f64) + # CUDA does not have a PTX instruction for f64 approximate tanh + asm = "tanh.approx.f64 $0, $1;" + constraint = "d" else: raise TypeError(f"approx_tanh does not accept {x.dtype} arrays") @@ -119,6 +124,13 @@ def _elementwise_inline_asm_lowering( result_shape_dtypes, ): del result_shape_dtypes # Unused. + + # For ROCm, PTX inline assembly is not supported. For tanh.approx, we use + # Triton's __triton_hip_fast_tanhf (fast exp-based formula) for f32, and + # OCML's __ocml_tanh_f64 for f64. See: https://github.com/triton-lang/triton/pull/7780 + if ctx.context.platform == "rocm" and "tanh.approx" in asm: + return _approx_tanh_rocm_lowering(ctx, *args) + return tt_dialect.ElementwiseInlineAsmOp( [*map(mlir.aval_to_ir_type, ctx.avals_out)], asm, @@ -129,6 +141,86 @@ def _elementwise_inline_asm_lowering( ).result +def _approx_tanh_rocm_lowering( + ctx: lowering.LoweringRuleContext, + *args, +): + """Lower approx_tanh for ROCm. + + AMD CDNA3 (MI300X/gfx942) does not have a hardware tanh instruction. + + For f32 (and f16/bf16 via casting): We use Triton's __triton_hip_fast_tanhf + which implements a fast exp-based formula: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) + See: https://github.com/triton-lang/triton/pull/7780 + + For f64: We use OCML's __ocml_tanh_f64 (AMD's Open Compute Math Library) + since fast_tanhf only supports f32. + """ + from jax._src.lib.mlir import ir + from jax._src.lib.mlir.dialects import arith as arith_dialect + + [arg] = args + [out_aval] = ctx.avals_out + in_dtype = ctx.avals_in[0].dtype + + # Helper to get IR type for a dtype + def dtype_to_ir_type(dtype): + dtype = jnp.dtype(dtype) + return mlir.dtype_to_ir_type(dtype) + + # f64: use __ocml_tanh_f64 (fast_tanhf only supports f32) + if in_dtype == jnp.float64: + result_type = mlir.aval_to_ir_type(out_aval) + result = tt_dialect.extern_elementwise( + result_type, + list(args), + libname="", + libpath="", + symbol="__ocml_tanh_f64", + pure=True, + ) + return [result] + + # fast_tanhf only supports f32. For f16/bf16, cast to f32, compute, cast back. + needs_cast = in_dtype in (jnp.float16, jnp.bfloat16) + + if needs_cast: + # Cast input to f32 (extend) + f32_type = dtype_to_ir_type(jnp.float32) + if out_aval.shape: + f32_result_type = ir.RankedTensorType.get(out_aval.shape, f32_type) + else: + f32_result_type = f32_type + arg_f32 = arith_dialect.extf(f32_result_type, arg) + + # Call __triton_hip_fast_tanhf (fast exp-based implementation) + tanh_result = tt_dialect.extern_elementwise( + f32_result_type, + [arg_f32], + libname="libdevice", + libpath="", + symbol="__triton_hip_fast_tanhf", + pure=True, + ) + + # Cast result back to original dtype (truncate) + out_type = mlir.aval_to_ir_type(out_aval) + result = arith_dialect.truncf(out_type, tanh_result) + else: + # f32: call __triton_hip_fast_tanhf directly + result_type = mlir.aval_to_ir_type(out_aval) + result = tt_dialect.extern_elementwise( + result_type, + list(args), + libname="libdevice", + libpath="", + symbol="__triton_hip_fast_tanhf", + pure=True, + ) + + return [result] + + def debug_barrier() -> None: """Synchronizes all kernel executions in the grid.""" return debug_barrier_p.bind() diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 707d89f38b9f..b5a3de52dd49 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1910,6 +1910,47 @@ def kernel(o_ref): np.testing.assert_allclose(f(), kernel()) + @parameterized.parameters("float16", "bfloat16", "float32", "float64") + def test_approx_tanh(self, dtype): + self.skip_if_mosaic_gpu() + + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented on TPU") + + if self.INTERPRET: + self.skipTest("approx_tanh is not supported in interpret mode") + + if (dtype == "bfloat16" and + jtu.test_device_matches(["cuda"]) and + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90") + + if dtype == "float64": + if jtu.test_device_matches(["cuda"]): + self.skipTest("f64 approx_tanh is only supported on ROCm") + + # Enable x64 for f64 test if not already enabled, restore after test + original_x64 = jax.config.x64_enabled + if dtype == "float64" and not original_x64: + jax.config.update("jax_enable_x64", True) + self.addCleanup(lambda: jax.config.update("jax_enable_x64", False)) + + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), + ) + def kernel(x_ref, o_ref): + o_ref[...] = plgpu_triton.approx_tanh(x_ref[...]) + + x = jnp.asarray([-1, 0.42, 0.24, 1]).astype(dtype) + # We upcast to float32 because NumPy <2.0 does not handle custom dtypes + # properly. See https://github.com/jax-ml/jax/issues/11014. + np.testing.assert_allclose( + kernel(x).astype(jnp.float32), + jnp.tanh(x).astype(jnp.float32), + atol=5e-3, + rtol=5e-3, + ) + @parameterized.parameters( ((2, 4), (8,)), ((2, 4), (8, 1)), From 90a3578d9ce25420cd051412adc9e9dbb5081223 Mon Sep 17 00:00:00 2001 From: AratiGanesh Date: Mon, 9 Feb 2026 05:08:48 -0800 Subject: [PATCH 40/43] Skipping testEighTinyNorm due to hipSolver issues (#697) --- tests/linalg_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 754cf5b09e64..5b38275947ff 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -570,7 +570,12 @@ def testEighZeroDiagonal(self): np.linalg.norm(np.matmul(a, v) - w * v), 2.5 * eps * np.linalg.norm(a) ) + def testEighTinyNorm(self): + if jtu.is_device_rocm(): + # numerical errors seen as of ROCm 7.2 due to hipSolver issue + # TODO: re-enable the test once the hipSolver issue is fixed + self.skipTest("testEighNorm not supported on ROCm due to hipSOLVER issue") rng = jtu.rand_default(self.rng()) a = rng((300, 300), dtype=np.float32) eps = jnp.finfo(a.dtype).eps From f7a64071bf90c77bbdc195ecf80353df19e6b64b Mon Sep 17 00:00:00 2001 From: Gulsum Gudukbay Akbulut Date: Fri, 20 Feb 2026 11:17:09 -0600 Subject: [PATCH 41/43] Abort detection CI workflow (#688) --- .github/workflows/pytest_rocm_abort.yml | 162 ++++++++++++ .../wheel_tests_nightly_release_abort.yml | 53 ++++ ci/run_pytest_rocm_abort.sh | 179 ++++++++++++++ conftest.py | 233 +++++------------- 4 files changed, 456 insertions(+), 171 deletions(-) create mode 100644 .github/workflows/pytest_rocm_abort.yml create mode 100644 .github/workflows/wheel_tests_nightly_release_abort.yml create mode 100755 ci/run_pytest_rocm_abort.sh diff --git a/.github/workflows/pytest_rocm_abort.yml b/.github/workflows/pytest_rocm_abort.yml new file mode 100644 index 000000000000..fcfdb06855d5 --- /dev/null +++ b/.github/workflows/pytest_rocm_abort.yml @@ -0,0 +1,162 @@ +# CI - Pytest ROCm (Abort Support) +# +# This workflow runs the ROCm tests with Pytest in ROCm GHCR containers, +# using the ROCm `pytest-abort` retry wrapper to detect/retry aborts/crashes. +# +# It can be triggered manually via workflow_dispatch or called by other workflows +# via workflow_call. +# +# It consists of the following job: +# run-tests: +# - Runs in ROCm container (ghcr.io/rocm/jax-base-ubu24-rocm*:latest) +# - Downloads the JAX and jaxlib wheels from GCS, and ROCm plugins from latest release. +# - Executes the `run_pytest_rocm_abort.sh` script, which installs wheel artifacts and +# runs the ROCm tests with Pytest under `pytest-abort-retry`. +name: CI - Pytest ROCm (Abort Support) + +on: + workflow_dispatch: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: choice + default: "linux-x86-64-4gpu-amd" + options: + - "linux-x86-64-1gpu-amd" + - "linux-x86-64-4gpu-amd" + - "linux-x86-64-8gpu-amd" + python: + description: "Which Python version to use?" + type: choice + default: "3.11" + options: + - "3.11" + - "3.12" + rocm-version: + description: "Which ROCm version to test?" + type: choice + default: "7.2.0" + options: + - "7.2.0" + rocm-tag: + description: "ROCm tag for container image (e.g., rocm720)" + type: string + default: "rocm720" + jaxlib-version: + description: "Which jaxlib version to use? (head/pypi_latest)" + type: choice + default: "head" + options: + - "head" + - "pypi_latest" + skip-download-jaxlib-and-plugins-from-gcs: + description: "Whether to skip downloading the jaxlib and plugins from GCS (e.g for testing a jax only release)" + type: choice + default: '0' + options: + - '0' + - '1' + gcs_download_uri: + description: "GCS location prefix from where the artifacts should be downloaded" + type: string + default: 'gs://jax-nightly-artifacts/latest' + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: string + default: 'no' + workflow_call: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: string + default: "linux-x86-64-4gpu-amd" + python: + description: "Which Python version to use?" + type: string + default: "3.11" + rocm-version: + description: "Which ROCm version to test?" + type: string + default: "7.2.0" + rocm-tag: + description: "ROCm tag for container image (e.g., rocm720)" + type: string + default: "rocm720" + jaxlib-version: + description: "Which jaxlib version to use? (head/pypi_latest)" + type: string + default: "head" + skip-download-jaxlib-and-plugins-from-gcs: + description: "Whether to skip downloading the jaxlib and plugins from GCS (e.g for testing a jax only release)" + default: '0' + type: string + gcs_download_uri: + description: "GCS location prefix from where the artifacts should be downloaded" + default: 'gs://jax-nightly-artifacts/latest' + type: string + +permissions: {} + +env: + UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" + +jobs: + run-tests: + defaults: + run: + # Set the shell to bash as GitHub actions run with /bin/sh by default + shell: bash + runs-on: ${{ inputs.runner }} + continue-on-error: true + # Run in ROCm GHCR container with GPU access + container: + image: ghcr.io/rocm/jax-base-ubu24.${{ inputs.rocm-tag }}:latest + credentials: + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --shm-size 64G --env-file /etc/podinfo/gha-gpu-isolation-settings + name: "${{ (contains(inputs.runner, '1gpu') && '1gpu') || + (contains(inputs.runner, '4gpu') && '4gpu') || + (contains(inputs.runner, '8gpu') && '8gpu') }}, ROCm ${{ inputs.rocm-version }}, py${{ inputs.python }}" + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" + JAXCI_PYTHON: "python${{ inputs.python }}" + JAXCI_ENABLE_X64: "0" + + steps: + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false + - name: Download JAX ROCm wheels + uses: ./.github/actions/download-jax-rocm-wheels + with: + python: ${{ inputs.python }} + rocm-version: ${{ inputs.rocm-version }} + jaxlib-version: ${{ inputs.jaxlib-version }} + skip-download-jaxlib-and-plugins-from-gcs: ${{ inputs.skip-download-jaxlib-and-plugins-from-gcs }} + gcs_download_uri: ${{ inputs.gcs_download_uri }} + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Install Python dependencies + run: | + $JAXCI_PYTHON -m pip install uv~=0.5.30 + $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Pytest ROCm tests (abort support) + timeout-minutes: 180 + run: ./ci/run_pytest_rocm_abort.sh + - name: Upload pytest results to artifact + if: always() + uses: actions/upload-artifact@v4 + with: + name: logs_abort + path: | + logs_abort/ + if-no-files-found: warn + retention-days: 2 + overwrite: true diff --git a/.github/workflows/wheel_tests_nightly_release_abort.yml b/.github/workflows/wheel_tests_nightly_release_abort.yml new file mode 100644 index 000000000000..518306956eff --- /dev/null +++ b/.github/workflows/wheel_tests_nightly_release_abort.yml @@ -0,0 +1,53 @@ +# CI - Wheel Tests (Nightly/Release) (ROCm abort only) +# +# This workflow runs only the ROCm wheel tests using the abort/retry wrapper workflow. +name: CI - Wheel Tests (Nightly/Release) (ROCm abort only) + +on: + workflow_dispatch: + inputs: + gcs_download_uri: + description: "GCS location URI from where the artifacts should be downloaded" + required: true + default: 'gs://jax-nightly-artifacts/latest' + type: string + skip-download-jaxlib-and-plugins-from-gcs: + description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" + required: true + default: '0' + type: string + halt-for-connection: + description: 'Should this workflow run wait for a remote connection? (yes/no)' + required: false + default: 'no' + type: string + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true +permissions: {} + +env: + UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" + +jobs: + run-pytest-rocm: + uses: ./.github/workflows/pytest_rocm_abort.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + runner: ["linux-x86-64-1gpu-amd", "linux-x86-64-4gpu-amd", "linux-x86-64-8gpu-amd"] + python: ["3.11", "3.12", "3.13", "3.14"] + rocm: [ + {version: "7.2.0", tag: "rocm720"}, + ] + name: "Pytest ROCm abort (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + rocm-version: ${{ matrix.rocm.version }} + rocm-tag: ${{ matrix.rocm.tag }} + jaxlib-version: "head" + skip-download-jaxlib-and-plugins-from-gcs: ${{inputs.skip-download-jaxlib-and-plugins-from-gcs}} + gcs_download_uri: ${{inputs.gcs_download_uri}} + halt-for-connection: ${{inputs.halt-for-connection}} diff --git a/ci/run_pytest_rocm_abort.sh b/ci/run_pytest_rocm_abort.sh new file mode 100755 index 000000000000..f17e7184b3bf --- /dev/null +++ b/ci/run_pytest_rocm_abort.sh @@ -0,0 +1,179 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Runs Pytest ROCm tests (with ROCm pytest-abort retry wrapper). +# Requires the jaxlib and ROCm plugin wheels to be present inside $JAXCI_OUTPUT_DIR (../dist) +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +source ci/envs/default.env + +# Install jaxlib and ROCm plugin wheels inside the $JAXCI_OUTPUT_DIR directory +echo "Installing wheels locally..." +source ./ci/utilities/install_wheels_locally.sh + +# Print all the installed packages +echo "Installed packages:" +"$JAXCI_PYTHON" -m uv pip freeze + +"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" + +rocm-smi + +# ============================================================================== +# Set up the generic test environment variables +# ============================================================================== +export PY_COLORS=1 +export JAX_SKIP_SLOW_TESTS=true +export NCCL_DEBUG=WARN +export TF_CPP_MIN_LOG_LEVEL=0 +export JAX_ENABLE_X64="$JAXCI_ENABLE_X64" + +# ============================================================================== +# Calculate the optimal number of parallel processes for pytest +# This will be the minimum of: GPU capacity, CPU core count, and a system RAM limit. +# ============================================================================== + +export gpu_count=$(rocminfo | egrep -c "Device Type:\\s+GPU") +echo "Number of GPUs detected: $gpu_count" + +# Query GPU 0 memory using rocm-smi +export memory_per_gpu_mib=$(rocm-smi -d 0 --showmeminfo vram | grep -i "vram total" | awk '{print int($NF/1024/1024)}' | head -1) +echo "Reported memory per GPU: $memory_per_gpu_mib MiB" + +# Convert effective memory from MiB to GiB. +export memory_per_gpu_gib=$((memory_per_gpu_mib / 1024)) +echo "Effective memory per GPU: $memory_per_gpu_gib GiB" + +# Allow 2 GiB of GPU RAM per test. +export max_tests_per_gpu=$((memory_per_gpu_gib / 2)) +echo "Max tests per GPU (assuming 2GiB/test): $max_tests_per_gpu" + +export num_processes=$((gpu_count * max_tests_per_gpu)) +echo "Initial number of processes based on GPU capacity: $num_processes" + +export num_cpu_cores=$(nproc) +echo "Number of CPU cores available: $num_cpu_cores" + +# Reads total memory from /proc/meminfo (in KiB) and converts to GiB. +export total_ram_gib=$(awk '/MemTotal/ {printf \"%.0f\", $2/1048576}' /proc/meminfo) +echo "Total system RAM: $total_ram_gib GiB" + +# Set a safety limit for system RAM usage, e.g., 1/6th of total. +export host_memory_limit=$((total_ram_gib / 6)) +echo "Host memory process limit (1/6th of total RAM): $host_memory_limit" + +if [[ $num_cpu_cores -lt $num_processes ]]; then + num_processes=$num_cpu_cores + echo "Adjusting num_processes to match CPU core count: $num_processes" +fi + +if [[ $host_memory_limit -lt $num_processes ]]; then + num_processes=$host_memory_limit + echo "Adjusting num_processes to match host memory limit: $num_processes" +fi + +if [[ 16 -lt $num_processes ]]; then + num_processes=16 + echo "Reducing num_processes to $num_processes" +fi + +echo "Final number of processes to run: $num_processes" + +export JAX_ENABLE_ROCM_XDIST="$gpu_count" +export XLA_PYTHON_CLIENT_ALLOCATOR=platform +export XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1 --xla_gpu_enable_nccl_comm_splitting=false --xla_gpu_enable_command_buffer=" + +# Disable core dumps just in case +ulimit -c 0 + +# Keep deselected tests in one place for the abort wrapper. +ROCM_PYTEST_DESELECT_ARGS=( + --deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data + --deselect=tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices + --deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric +) + +# --max-runs: retry the entire pytest run up to N times on abort/crash. +# --max-worker-restart: restart crashed xdist workers up to N times. +# --maxfail: stop the run after N test failures. +rocm_test_cmd() { + local abort_flag="${1:-0}" + shift + if [[ "$abort_flag" == "1" ]]; then + pytest-abort-retry --max-runs 3 --clear-crash-log -- "$@" + else + "$@" + fi +} + +rocm_log_tail_on_failure() { + local logfile="$1" + local status="$2" + if [[ "$status" -ne 0 ]]; then + echo "Pytest failed (exit=$status). Showing last 200 lines of $logfile:" + tail -n 200 "$logfile" || true + else + echo "Pytest output saved to $logfile (uploaded as artifact)." + fi +} + +rocm_install_extra_requirements() { + if [[ -n "${GITHUB_WORKSPACE:-}" ]]; then + cd "$GITHUB_WORKSPACE" + fi + + # Install extra requirements. + "$JAXCI_PYTHON" -m uv pip install pytest-timeout pytest-html pytest-csv pytest-json-report pytest-abort +} + +rocm_install_extra_requirements + +echo "Running ROCm tests (with abort/retry wrapper)..." +mkdir -p logs_abort +logfile="logs_abort/jax_ToT_UT_abort.log" + +# pytest-abort output directories (must be set before running pytest). +export PYTEST_ABORT_LAST_RUNNING_DIR="logs_abort/last_running" +export PYTEST_ABORT_CRASHED_TESTS_LOG="logs_abort/crashed_tests.jsonl" +mkdir -p "$PYTEST_ABORT_LAST_RUNNING_DIR" + +set +e +rocm_test_cmd 1 "$JAXCI_PYTHON" -m pytest -n "$num_processes" --max-worker-restart=200 --tb=short --timeout=1200 --timeout-method=thread tests \ + "${ROCM_PYTEST_DESELECT_ARGS[@]}" \ + --json-report \ + --json-report-file=logs_abort/tests-report-abort.json \ + --csv=logs_abort/tests-report-abort.csv \ + --html=logs_abort/tests-report-abort.html \ + --self-contained-html \ + >"$logfile" 2>&1 +pytest_status=$? +set -e +rocm_log_tail_on_failure "$logfile" "$pytest_status" + +echo "Postprocessing reports with crashed tests..." +pytest-abort-postprocess \ + --crash-log "$PYTEST_ABORT_CRASHED_TESTS_LOG" \ + --json-report logs_abort/tests-report-abort.json \ + --html-report logs_abort/tests-report-abort.html \ + --csv-report logs_abort/tests-report-abort.csv \ + >>"$logfile" 2>&1 + +exit "$pytest_status" diff --git a/conftest.py b/conftest.py index 72b4b598891c..ca144a6dcee4 100644 --- a/conftest.py +++ b/conftest.py @@ -15,10 +15,6 @@ import os import pytest -import json -import threading -import shutil -from datetime import datetime @pytest.fixture(autouse=True) def add_imports(doctest_namespace): @@ -76,174 +72,69 @@ def pytest_collection() -> None: "CUDA_VISIBLE_DEVICES", str(xdist_worker_number % num_cuda_devices) ) -class ThreadSafeTestLogger: - """Thread-safe logging for parallel test execution and abort detection""" - def __init__(self): - self.locks = {} - self.global_lock = threading.Lock() - self.base_dir = os.path.abspath("./logs") - - # Create logs directory (archiving is handled by test runner scripts) - try: - os.makedirs(self.base_dir, exist_ok=True) - print(f"[TestLogger] Initialized log directory: {self.base_dir}") - except Exception as e: - print(f"[TestLogger] ERROR: Failed to create log directory {self.base_dir}: {e}") - # Fallback to temp directory if logs dir creation fails - import tempfile - self.base_dir = os.path.join(tempfile.gettempdir(), "jax_test_logs") - os.makedirs(self.base_dir, exist_ok=True) - print(f"[TestLogger] Using fallback directory: {self.base_dir}") - - def get_file_lock(self, test_file): - """Get or create a lock for a specific test file""" - with self.global_lock: - if test_file not in self.locks: - self.locks[test_file] = threading.Lock() - return self.locks[test_file] - - def get_test_file_name(self, session): - """Extract the test file name from the session""" - # Try to get from session config args - if hasattr(session, "config") and hasattr(session.config, "args"): - for arg in session.config.args: - # Handle full nodeid like "jax/tests/foo_test.py::TestClass::test_method" - if "tests/" in arg: - # Split on :: to get just the file path - file_path = arg.split("::")[0] - if file_path.endswith(".py"): - return os.path.basename(file_path).replace(".py", "") - - # Try to get from invocation params - if hasattr(session, "config") and hasattr(session.config, "invocation_params"): - invocation_dir = getattr(session.config.invocation_params, "dir", None) - if invocation_dir: - dir_name = os.path.basename(str(invocation_dir)) - if dir_name: - print(f"[TestLogger] Using invocation directory as test name: {dir_name}") - return dir_name - - # Last resort: try to get from session items - if hasattr(session, "items") and session.items: - first_item = session.items[0] - if hasattr(first_item, "fspath"): - fspath = str(first_item.fspath) - if ".py" in fspath: - return os.path.basename(fspath).replace(".py", "") - - print(f"[TestLogger] WARNING: Could not determine test file name, using 'unknown_test'") - print(f"[TestLogger] Session config args: {getattr(session.config, 'args', 'N/A')}") - return "unknown_test" - - def log_running_test(self, test_file, test_name, nodeid, start_time): - """Log the currently running test for abort detection""" - lock = self.get_file_lock(test_file) - with lock: - log_data = { - "test_file": test_file, - "test_name": test_name, - "nodeid": nodeid, - "start_time": start_time, - "status": "running", - "pid": os.getpid(), - "gpu_id": os.environ.get("HIP_VISIBLE_DEVICES", "unknown"), - } - - log_file = f"{self.base_dir}/{test_file}_last_running.json" - try: - # Ensure directory still exists (might have been deleted) - os.makedirs(self.base_dir, exist_ok=True) - with open(log_file, "w") as f: - json.dump(log_data, f, indent=2) - except Exception as e: - print(f"[TestLogger] ERROR: Failed to write running test log to {log_file}: {e}") - print(f"[TestLogger] Current working directory: {os.getcwd()}") - print(f"[TestLogger] Base directory: {self.base_dir}") - print(f"[TestLogger] Base directory exists: {os.path.exists(self.base_dir)}") - raise - - def clear_running_test(self, test_file): - """Clear the running test log when test completes successfully""" - lock = self.get_file_lock(test_file) - with lock: - log_file = f"{self.base_dir}/{test_file}_last_running.json" - if os.path.exists(log_file): - os.remove(log_file) - - -# Global logger instance -test_logger = ThreadSafeTestLogger() - - -@pytest.hookimpl(hookwrapper=True) -def pytest_runtest_protocol(item, nextitem): - """Hook that wraps around each test to track running tests for crash detection. - - This creates a "last_running" file before each test starts and deletes it - when the test completes successfully. If the test crashes, the file remains - and can be detected by the test runner. - """ - test_file = test_logger.get_test_file_name(item.session) - test_name = item.name - nodeid = item.nodeid - start_time = datetime.now().isoformat() - - # Log that this test is starting - try: - test_logger.log_running_test(test_file, test_name, nodeid, start_time) - except Exception as e: - print(f"[TestLogger] WARNING: Failed to log running test: {e}") - # Continue anyway - not critical for test execution - - test_completed = False - try: - outcome = yield - # Test completed (successfully or with normal failure) - test_completed = True - - # Clear the crash detection file - try: - test_logger.clear_running_test(test_file) - except Exception as e: - print(f"[TestLogger] WARNING: Failed to clear running test log: {e}") - - except Exception as e: - # Test raised exception (might be crash, might be normal exception) - print(f"[TestLogger] Test {test_name} exception: {e}") - if not test_completed: - # Don't clear the file - this might be a crash - print(f"[TestLogger] Leaving crash file for detection") - raise + elif num_rocm_devices := os.environ.get("JAX_ENABLE_ROCM_XDIST", None): + num_rocm_devices = int(num_rocm_devices) + xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "") + if not xdist_worker_name.startswith("gw"): + return + xdist_worker_number = int(xdist_worker_name[len("gw") :]) + assigned = str(xdist_worker_number % num_rocm_devices) + # If ROCR_VISIBLE_DEVICES is set, don't also set HIP_VISIBLE_DEVICES + # (double-filtering can produce HIP_ERROR_NoDevice). Respect the outer setting. + if os.environ.get("ROCR_VISIBLE_DEVICES"): + return -@pytest.hookimpl(tryfirst=True) -def pytest_sessionstart(session): - """Called after the Session object has been created""" - gpu = os.environ.get('HIP_VISIBLE_DEVICES', '?') - print(f"Test session starting on GPU {gpu}") + # If present-but-empty, this can hide all GPUs. + if os.environ.get("HIP_VISIBLE_DEVICES", None) == "": + del os.environ["HIP_VISIBLE_DEVICES"] + + # HIP layer isolation (ROCm also accepts CUDA_VISIBLE_DEVICES, but we avoid it here). + os.environ["HIP_VISIBLE_DEVICES"] = assigned + +def pytest_configure(config) -> None: + # Real pytest hook (runs early in main + each xdist worker). + xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "") or "main" + + # xdist master: print planned mapping (worker stdout is often hidden) + numproc = int(getattr(getattr(config, "option", None), "numprocesses", 0) or 0) + if xdist_worker_name == "main" and numproc > 0: + hip0 = (os.environ.get("HIP_VISIBLE_DEVICES") or "").strip() + cuda_x = (os.environ.get("JAX_ENABLE_CUDA_XDIST") or "").strip() + tpu_x = (os.environ.get("JAX_ENABLE_TPU_XDIST") or "").strip() + rocm_x = (os.environ.get("JAX_ENABLE_ROCM_XDIST") or "").strip() + if cuda_x: + try: + ndev = int(cuda_x) + except ValueError: + ndev = 0 + if ndev > 0: + mapping = ", ".join(f"gw{i}->CUDA_VISIBLE_DEVICES={i % ndev}" for i in range(numproc)) + print(f"[DeviceVisibility] xdist planned mapping: {mapping}", flush=True) + elif tpu_x: + mapping = ", ".join(f"gw{i}->TPU_VISIBLE_CHIPS={i}" for i in range(numproc)) + print(f"[DeviceVisibility] xdist planned mapping: {mapping}", flush=True) + elif rocm_x: + try: + ndev = int(rocm_x) + except ValueError: + ndev = 0 + if ndev > 0: + mapping = ", ".join(f"gw{i}->HIP_VISIBLE_DEVICES={i % ndev}" for i in range(numproc)) + print(f"[DeviceVisibility] xdist planned mapping: {mapping}", flush=True) + elif hip0: + print(f"[DeviceVisibility] master HIP_VISIBLE_DEVICES={hip0}", flush=True) + if os.environ.get("JAX_ENABLE_TPU_XDIST", None): + if xdist_worker_name.startswith("gw"): + xdist_worker_number = int(xdist_worker_name[len("gw") :]) + os.environ.setdefault("TPU_VISIBLE_CHIPS", str(xdist_worker_number)) + os.environ.setdefault("ALLOW_MULTIPLE_LIBTPU_LOAD", "true") -@pytest.hookimpl(trylast=True) -def pytest_sessionfinish(session, exitstatus): - """Called after test run finished. - - If a crash file still exists, it means a test crashed and the runner - will detect it. We just report it here for visibility. - """ - test_file = test_logger.get_test_file_name(session) - log_file = f"{test_logger.base_dir}/{test_file}_last_running.json" - - if os.path.exists(log_file): - try: - with open(log_file, "r") as f: - abort_data = json.load(f) - print( - f"\n[CRASH DETECTED] {abort_data.get('nodeid', abort_data.get('test_name', 'unknown'))} " - f"(GPU: {abort_data.get('gpu_id', '?')}, PID: {abort_data.get('pid', '?')})" - ) - print(f"[CRASH DETECTED] Crash file will be processed by test runner") - except Exception as e: - print(f"[TestLogger] WARNING: Crash file exists but unreadable: {e}") - else: - # Normal completion - no crash - pass + elif num_cuda_devices := os.environ.get("JAX_ENABLE_CUDA_XDIST", None): + if xdist_worker_name.startswith("gw"): + num_cuda_devices = int(num_cuda_devices) + xdist_worker_number = int(xdist_worker_name[len("gw") :]) + os.environ.setdefault( + "CUDA_VISIBLE_DEVICES", str(xdist_worker_number % num_cuda_devices) + ) From 16044530cc33416f45beada9b43d01f6029e80fb Mon Sep 17 00:00:00 2001 From: Gulsum Gudukbay Akbulut Date: Sat, 21 Feb 2026 01:01:55 -0600 Subject: [PATCH 42/43] Implement Mosaic GPU detection and Auto-Skips Add functions to detect and manage Mosaic GPU usage in tests and auto-skip them on ROCm. --- conftest.py | 183 +++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 182 insertions(+), 1 deletion(-) diff --git a/conftest.py b/conftest.py index ca144a6dcee4..ecd884dad01d 100644 --- a/conftest.py +++ b/conftest.py @@ -16,6 +16,181 @@ import os import pytest +# Mosaic GPU checking based on test *file path* only (avoid test-name substrings). +_MOSAIC_GPU_PATH_NEEDLES = ( + f"{os.sep}tests{os.sep}mosaic{os.sep}", + f"{os.sep}tests{os.sep}pallas{os.sep}mgpu_", + f"{os.sep}tests{os.sep}pallas{os.sep}mosaic_gpu", + f"{os.sep}tests{os.sep}pallas{os.sep}mosaic", +) + +# Simple Mosaic GPU *usage* substring checks (avoid import-only signals). +_MOSAIC_GPU_SOURCE_NEEDLES = ( + "inline_mgpu", + "plgpu_mgpu.", + "mosaic_gpu_interpret", + "mosaic_gpu_backend", + "jax.experimental.mosaic.gpu", # runtime usage in body (not module import scan) + "jax.experimental.pallas.mosaic_gpu", +) + + +def _pallas_defaults_to_mosaic_gpu() -> bool: + """Returns True if Pallas GPU lowering defaults to Mosaic GPU.""" + try: + from jax._src.pallas import pallas_call as pallas_call_lib # pytype: disable=import-error + return bool(pallas_call_lib._PALLAS_USE_MOSAIC_GPU.value) # pylint: disable=protected-access + except Exception: + return False + + +def _running_on_rocm() -> bool: + """Best-effort ROCm detection. + + First tries to check rocm in jaxlib version, falls back to checking backend + platform_version so that it works for ROCm PJRT plugin installs where jaxlib's + version tag may not contain rocm. + """ + try: + import jaxlib.version as jaxlib_version # pytype: disable=import-error + version_str = getattr(jaxlib_version, "__version__", "") + except Exception: + version_str = "" + if "rocm" in version_str.lower(): + return True + try: + import jax # pytype: disable=import-error + from jax._src import xla_bridge # pytype: disable=import-error + backend = xla_bridge.get_backend() + pv = getattr(backend, "platform_version", "") or "" + return "rocm" in str(pv).lower() + except Exception: + return False + + +def _source_mentions_mosaic_gpu(src: str) -> bool: + """Returns True if the test file content has Mosaic GPU usage.""" + lowered = src.lower() + return any(n in lowered for n in _MOSAIC_GPU_SOURCE_NEEDLES) + + +def _looks_like_mosaic_gpu_path(path_str: str) -> bool: + """Returns True if the path is a Mosaic-GPU-only test file.""" + lowered = path_str.lower() + return any(n.lower() in lowered for n in _MOSAIC_GPU_PATH_NEEDLES) + + +def _class_mosaic_override(cls: type | None, cache: dict[object, object]) -> bool | None: + """Detects explicit class-level Mosaic enable/disable. + + Returns: + - True if the class forces Mosaic GPU (`_PALLAS_USE_MOSAIC_GPU(True)`). + - False if it forces Triton (`_PALLAS_USE_MOSAIC_GPU(False)`). + - None if no explicit override is found. + """ + if cls is None: + return None + cache_key = ("__mosaic_override__", cls) + if cache_key in cache: + return cache[cache_key] # type: ignore[return-value] + import inspect + try: + src = inspect.getsource(cls).lower() + except Exception: + cache[cache_key] = None + return None + if "_pallas_use_mosaic_gpu(true" in src: + cache[cache_key] = True + return True + if "_pallas_use_mosaic_gpu(false" in src: + cache[cache_key] = False + return False + cache[cache_key] = None + return None + + +def _is_mosaic_gpu_item( + item: pytest.Item, + cache: dict[object, bool], + *, + running_on_rocm: bool, + pallas_defaults_to_mosaic: bool, +) -> bool: + """Returns True if this test item uses (or would use) Mosaic GPU.""" + path_obj = getattr(item, "path", None) or getattr(item, "fspath", None) + path_str = str(path_obj) if path_obj is not None else "" + if _looks_like_mosaic_gpu_path(path_str): + return True + + import inspect + + obj = getattr(item, "obj", None) + if obj is None: + return False + if obj in cache: + return cache[obj] + try: + src = inspect.getsource(obj) + except Exception: + cache[obj] = False + return False + + lowered = src.lower() + # Direct Mosaic usage in the test function/method. + if _source_mentions_mosaic_gpu(lowered): + cache[obj] = True + return True + + # Respect explicit class-level override: if a test class forces Mosaic off, + # we should not skip it just because Pallas defaults to Mosaic elsewhere. + cls_override = _class_mosaic_override(getattr(item, "cls", None), cache) # type: ignore[arg-type] + if cls_override is False: + cache[obj] = False + return False + if cls_override is True: + cache[obj] = True + return True + + # Implicit Mosaic usage: on ROCm, `pallas_call` defaults to Mosaic GPU when + # `compiler_params` is not specified and Mosaic is the default backend. + if running_on_rocm and pallas_defaults_to_mosaic: + uses_pallas_call = ( + ".pallas_call" in lowered + or "pl.pallas_call" in lowered + or "pallas_call(" in lowered + ) + explicitly_selects_compiler = "compiler_params=" in lowered + if uses_pallas_call and not explicitly_selects_compiler: + cache[obj] = True + return True + + cache[obj] = False + return False + + +def pytest_collection_modifyitems( + config: pytest.Config, items: list[pytest.Item] +) -> None: + """Mark Mosaic GPU tests and skip them on ROCm.""" + running_on_rocm = _running_on_rocm() + pallas_defaults_to_mosaic = _pallas_defaults_to_mosaic_gpu() if running_on_rocm else False + cache: dict[object, bool] = {} + for item in items: + is_mosaic_gpu = _is_mosaic_gpu_item( + item, + cache, + running_on_rocm=running_on_rocm, + pallas_defaults_to_mosaic=pallas_defaults_to_mosaic, + ) + if not is_mosaic_gpu: + continue + item.add_marker(pytest.mark.mosaic_gpu) + if running_on_rocm: + item.add_marker(pytest.mark.skip( + reason="Mosaic GPU tests are not supported on ROCm" + )) + + @pytest.fixture(autouse=True) def add_imports(doctest_namespace): import jax @@ -92,7 +267,13 @@ def pytest_collection() -> None: # HIP layer isolation (ROCm also accepts CUDA_VISIBLE_DEVICES, but we avoid it here). os.environ["HIP_VISIBLE_DEVICES"] = assigned -def pytest_configure(config) -> None: +def pytest_configure(config: pytest.Config) -> None: + """Register custom pytest markers and print attached GPUs to xdist workers.""" + config.addinivalue_line( + "markers", + "mosaic_gpu: tests that use Mosaic GPU (skipped on ROCm)", + ) + # Real pytest hook (runs early in main + each xdist worker). xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "") or "main" From 95a4d4aed8309b978eb337b434b12d938de97244 Mon Sep 17 00:00:00 2001 From: Gulsum Gudukbay Akbulut Date: Sat, 21 Feb 2026 06:22:48 -0600 Subject: [PATCH 43/43] Fix indentation for adding pytest marker --- conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conftest.py b/conftest.py index ecd884dad01d..9578612e498e 100644 --- a/conftest.py +++ b/conftest.py @@ -269,7 +269,7 @@ def pytest_collection() -> None: def pytest_configure(config: pytest.Config) -> None: """Register custom pytest markers and print attached GPUs to xdist workers.""" - config.addinivalue_line( + config.addinivalue_line( "markers", "mosaic_gpu: tests that use Mosaic GPU (skipped on ROCm)", )