From 097ff78e1f6ccb9f349db81768a0a0f338666da9 Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Wed, 12 Nov 2025 21:39:01 +0000 Subject: [PATCH 01/44] Remove nvidia_wheel_versions --- jaxlib/jax.bzl | 2 -- 1 file changed, 2 deletions(-) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 2e2757f2ecc4..7989c6c3f4c7 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -20,7 +20,6 @@ load("@jax_wheel//:wheel.bzl", "WHEEL_VERSION") load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "WHEEL_VERSION_SUFFIX") load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") -load("@nvidia_wheel_versions//:versions.bzl", "NVIDIA_WHEEL_VERSIONS") load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION", "HERMETIC_PYTHON_VERSION_KIND") load("@rocm_external_test_deps//:external_deps.bzl", "EXTERNAL_DEPS") load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library") @@ -461,7 +460,6 @@ def _jax_wheel_impl(ctx): if ctx.attr.platform_version == "": fail("platform_version must be set to a valid cuda version for cuda wheels") args.add("--platform_version", ctx.attr.platform_version) # required for gpu wheels - args.add("--nvidia_wheel_versions_data", NVIDIA_WHEEL_VERSIONS) # required for gpu wheels if ctx.attr.enable_rocm: args.add("--enable-rocm", "True") if ctx.attr.platform_version == "": From 19a87c5bc1d5884ff3396250dd84c1d6fbe111df Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Wed, 12 Nov 2025 21:42:44 +0000 Subject: [PATCH 02/44] Make jaxlib targets visible --- jaxlib/BUILD | 2 +- jaxlib/rocm/BUILD | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index d9a3a619965c..639803e667d4 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -39,7 +39,7 @@ licenses(["notice"]) package( default_applicable_licenses = 
[], - default_visibility = ["//jax:internal"], + default_visibility = ["//visibility:public"], ) package_group( diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 72f2ec6dae80..4f6c21a10c33 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -27,7 +27,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//visibility:public"], ) cc_library( From 35b2368988872276a084e5c23c1e4c841ce57bb6 Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Wed, 12 Nov 2025 22:12:01 +0000 Subject: [PATCH 03/44] hipblas typedef fix --- jaxlib/gpu/solver_interface.cc | 12 ++++++------ jaxlib/gpu/vendor.h | 12 ++++-------- jaxlib/rocm/BUILD | 4 ++++ 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/jaxlib/gpu/solver_interface.cc b/jaxlib/gpu/solver_interface.cc index e10c08d54e9f..8866c7bea2fe 100644 --- a/jaxlib/gpu/solver_interface.cc +++ b/jaxlib/gpu/solver_interface.cc @@ -62,8 +62,8 @@ JAX_GPU_DEFINE_GETRF(gpuDoubleComplex, gpusolverDnZgetrf); JAX_GPU_DEFINE_GETRF_BATCHED(float, gpublasSgetrfBatched); JAX_GPU_DEFINE_GETRF_BATCHED(double, gpublasDgetrfBatched); -JAX_GPU_DEFINE_GETRF_BATCHED(gpublasComplex, gpublasCgetrfBatched); -JAX_GPU_DEFINE_GETRF_BATCHED(gpublasDoubleComplex, gpublasZgetrfBatched); +JAX_GPU_DEFINE_GETRF_BATCHED(gpuComplex, gpublasCgetrfBatched); +JAX_GPU_DEFINE_GETRF_BATCHED(gpuDoubleComplex, gpublasZgetrfBatched); #undef JAX_GPU_DEFINE_GETRF_BATCHED // QR decomposition: geqrf @@ -101,8 +101,8 @@ JAX_GPU_DEFINE_GEQRF(gpuDoubleComplex, gpusolverDnZgeqrf); JAX_GPU_DEFINE_GEQRF_BATCHED(float, gpublasSgeqrfBatched); JAX_GPU_DEFINE_GEQRF_BATCHED(double, gpublasDgeqrfBatched); -JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasComplex, gpublasCgeqrfBatched); -JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasDoubleComplex, gpublasZgeqrfBatched); +JAX_GPU_DEFINE_GEQRF_BATCHED(gpuComplex, gpublasCgeqrfBatched); +JAX_GPU_DEFINE_GEQRF_BATCHED(gpuDoubleComplex, gpublasZgeqrfBatched); #undef 
JAX_GPU_DEFINE_GEQRF_BATCHED // Householder transformations: orgqr @@ -272,8 +272,8 @@ JAX_GPU_DEFINE_SYEVD(gpuDoubleComplex, gpusolverDnZheevd); JAX_GPU_DEFINE_SYRK(float, gpublasSsyrk); JAX_GPU_DEFINE_SYRK(double, gpublasDsyrk); -JAX_GPU_DEFINE_SYRK(gpublasComplex, gpublasCsyrk); -JAX_GPU_DEFINE_SYRK(gpublasDoubleComplex, gpublasZsyrk); +JAX_GPU_DEFINE_SYRK(gpuComplex, gpublasCsyrk); +JAX_GPU_DEFINE_SYRK(gpuDoubleComplex, gpublasZsyrk); #undef JAX_GPU_DEFINE_SYRK // Singular Value Decomposition: gesvd diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 4e6c4ca9a7d4..fbe5b30daadf 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -446,6 +446,8 @@ inline constexpr uint32_t kNumThreadsPerWarp = 32; #elif defined(JAX_GPU_HIP) +#define HIPBLAS_V2 1 + // IWYU pragma: begin_exports #include "rocm/include/hip/hip_cooperative_groups.h" #include "rocm/include/hip/hip_runtime_api.h" @@ -466,17 +468,11 @@ inline constexpr uint32_t kNumThreadsPerWarp = 32; // MIOpen lib. Remove when MIOpen support is complete. 
#define MIOPEN_STATUS_SUCCESS 0 -typedef hipFloatComplex gpuComplex; +typedef hipComplex gpuComplex; typedef hipDoubleComplex gpuDoubleComplex; -#if TF_ROCM_VERSION >= 70000 -typedef hipFloatComplex gpublasComplex; +typedef hipComplex gpublasComplex; typedef hipDoubleComplex gpublasDoubleComplex; -#else -typedef hipblasComplex gpublasComplex; -typedef hipblasDoubleComplex gpublasDoubleComplex; -#endif // TF_ROCM_VERSION >= 70000 - typedef struct hipsolverHandle_* gpusolverDnHandle_t; typedef hipblasFillMode_t gpublasFillMode_t; typedef hipsolverFillMode_t gpusolverFillMode_t; diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 4f6c21a10c33..bd810f4a74cf 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -404,6 +404,10 @@ nanobind_extension( "@nanobind", "@xla//xla/ffi/api:ffi", ], + linkopts = [ + "-L/opt/rocm/lib", + "-lamdhip64", + ], ) cc_library( From 9bf2dbf6753dd262b2c08e8e6d55f86ba169d706 Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Thu, 13 Nov 2025 21:50:02 +0000 Subject: [PATCH 04/44] No GPU fail --- jaxlib/rocm/BUILD | 1 + jaxlib/rocm/rocm_plugin_extension.cc | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index bd810f4a74cf..f266baeb4683 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -535,6 +535,7 @@ nanobind_extension( srcs = ["rocm_plugin_extension.cc"], module_name = "rocm_plugin_extension", deps = [ + ":hip_gpu_kernel_helpers", ":py_client_gpu", "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:gpu_plugin_extension", diff --git a/jaxlib/rocm/rocm_plugin_extension.cc b/jaxlib/rocm/rocm_plugin_extension.cc index 05be1e81c858..e28e4927b81f 100644 --- a/jaxlib/rocm/rocm_plugin_extension.cc +++ b/jaxlib/rocm/rocm_plugin_extension.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "jaxlib/gpu/gpu_plugin_extension.h" #include "jaxlib/gpu/py_client_gpu.h" #include "jaxlib/kernel_nanobind_helpers.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" namespace nb = nanobind; @@ -96,6 +97,13 @@ nb::dict FfiHandlers() { return dict; } +int ROCmDeviceCount() { + int device_count = -1; + JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipInit(0))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipGetDeviceCount(&device_count))); + return device_count; +} + } // namespace NB_MODULE(rocm_plugin_extension, m) { @@ -122,5 +130,6 @@ NB_MODULE(rocm_plugin_extension, m) { return device_ordinal; }, nb::arg("data_value")); + m.def("get_device_count", &ROCmDeviceCount); } } // namespace jax From 7d1708e4716219bc555ae31717e8e9b63d6b48ac Mon Sep 17 00:00:00 2001 From: Marco Minutoli Date: Thu, 12 Feb 2026 14:06:16 -0800 Subject: [PATCH 05/44] Wrap HIP inline functions in anonymous namespaces in vendor.h --- jaxlib/gpu/vendor.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index fbe5b30daadf..915a36e5bdd3 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -529,6 +529,7 @@ inline hipblasStatus_t gpublasCreate(gpublasHandle_t* handle) { return hipblasCreate(reinterpret_cast(handle)); } } // namespace jax::hip + #define gpublasCreate ::jax::hip::gpublasCreate #define gpublasSetStream hipblasSetStream #define gpublasSgeqrfBatched hipblasSgeqrfBatched @@ -585,6 +586,7 @@ inline hipsolverStatus_t gpusolverDnCreate(gpusolverDnHandle_t* handle) { return hipsolverCreate(reinterpret_cast(handle)); } } // namespace jax::hip + #define gpusolverDnCreate ::jax::hip::gpusolverDnCreate #define gpusolverDnSetStream hipsolverSetStream #define gpusolverDnCreateSyevjInfo hipsolverCreateSyevjInfo From 30d7f94d5adecd15b2b1e7edc1003e54daec0360 Mon Sep 17 00:00:00 2001 From: Dragoslav Sicarov Date: Tue, 10 Jun 2025 04:28:40 +0000 Subject: [PATCH 06/44] SWDEV-512768 - Replace hipGetLastError with hipExtGetLastError --- jaxlib/gpu/vendor.h | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 915a36e5bdd3..610a15aecb52 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -772,7 +772,7 @@ inline hipsparseStatus_t gpusparseCreate(gpusparseHandle_t* handle) { #define GPU_STREAM_NON_BLOCKING hipStreamNonBlocking #define gpuMalloc hipMalloc -#define gpuGetLastError hipGetLastError +#define gpuGetLastError hipExtGetLastError #define gpuGetErrorString hipGetErrorString #define gpuMemcpyAsync hipMemcpyAsync #define gpuMemcpyDeviceToDevice hipMemcpyDeviceToDevice From a5377e541ba13b735ec1321f5c6fbfbf14114049 Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Fri, 14 Nov 2025 19:43:09 +0000 Subject: [PATCH 07/44] Add shared utility function get_rocm_version to test_util.py --- jax/_src/test_util.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index bbfadb69581a..5f24adcd694c 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -25,6 +25,7 @@ import logging import math import os +from pathlib import Path import platform import re import sys @@ -392,6 +393,15 @@ def supported_dtypes(): def is_device_rocm(): return 'rocm' in xla_bridge.get_backend().platform_version +def get_rocm_version(): + rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm") + version_path = Path(rocm_path) / ".info" / "version" + if not version_path.exists(): + raise FileNotFoundError(f"Expected ROCm version file at {version_path}") + version_str = version_path.read_text().strip() + major, minor, *_ = version_str.split(".") + return int(major), int(minor) + def is_device_cuda(): return 'cuda' in xla_bridge.get_backend().platform_version From db30afa3f9396fe32f905afce3d9a6e2e5520f3c Mon Sep 17 00:00:00 2001 From: Pham Binh Date: Mon, 17 Nov 2025 19:34:29 +0000 Subject: [PATCH 08/44] Fix hipSparse CSR algorithm mappings for ROCm 7 --- jaxlib/gpu/vendor.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) 
diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 610a15aecb52..be985a8a2306 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -756,10 +756,10 @@ inline hipsparseStatus_t gpusparseCreate(gpusparseHandle_t* handle) { #define GPUSPARSE_INDEX_32I HIPSPARSE_INDEX_32I #define GPUSPARSE_INDEX_64I HIPSPARSE_INDEX_64I #define GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT -#define GPUSPARSE_SPMV_COO_ALG HIPSPARSE_MV_ALG_DEFAULT -#define GPUSPARSE_SPMV_CSR_ALG HIPSPARSE_MV_ALG_DEFAULT -#define GPUSPARSE_SPMM_COO_ALG HIPSPARSE_SPMM_ALG_DEFAULT -#define GPUSPARSE_SPMM_CSR_ALG HIPSPARSE_SPMM_ALG_DEFAULT +#define GPUSPARSE_SPMV_COO_ALG HIPSPARSE_COOMV_ALG +#define GPUSPARSE_SPMV_CSR_ALG HIPSPARSE_CSRMV_ALG1 +#define GPUSPARSE_SPMM_COO_ALG HIPSPARSE_SPMM_COO_ALG1 +#define GPUSPARSE_SPMM_CSR_ALG HIPSPARSE_SPMM_CSR_ALG1 #define GPUSPARSE_INDEX_BASE_ZERO HIPSPARSE_INDEX_BASE_ZERO #define GPUSPARSE_OPERATION_NON_TRANSPOSE HIPSPARSE_OPERATION_NON_TRANSPOSE #define GPUSPARSE_OPERATION_TRANSPOSE HIPSPARSE_OPERATION_TRANSPOSE From a44f9428580f711818e8c7fab84fd7109e902098 Mon Sep 17 00:00:00 2001 From: Pham Binh Date: Thu, 20 Nov 2025 01:30:24 +0200 Subject: [PATCH 09/44] =?UTF-8?q?Fix=20v=5Fpages=20quantization=20and=20ad?= =?UTF-8?q?just=20test=20params=20for=20ROCm=20compatibilit=E2=80=A6=20(#5?= =?UTF-8?q?60)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/pallas/gpu_paged_attention_test.py | 71 +++++++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/tests/pallas/gpu_paged_attention_test.py b/tests/pallas/gpu_paged_attention_test.py index 1b778c787a6d..6a1d8de22a78 100644 --- a/tests/pallas/gpu_paged_attention_test.py +++ b/tests/pallas/gpu_paged_attention_test.py @@ -112,6 +112,64 @@ class PagedAttentionKernelTest(PallasBaseTest): def setUp(self): super().setUp() + def _estimate_shared_memory_bytes(self, block_h, pages_per_compute_block, + page_size, 
head_dim, dtype): + """Estimate shared memory usage for paged attention kernel.""" + dtype_size = jnp.dtype(dtype).itemsize + # Approximate calculation based on kernel's memory usage + # Q block: block_h * head_dim + # K/V blocks: pages_per_compute_block * page_size * head_dim + # Plus accumulators and intermediate values + block_k = pages_per_compute_block * page_size + estimated = dtype_size * ( + block_h * head_dim + # Q + 2 * block_k * head_dim + # K and V + block_h * block_k + # logits/attention weights + block_h * 8 # accumulators (m, l, etc.) in float32 + ) + return estimated + + def _adjust_params_for_shared_memory(self, block_h, pages_per_compute_block, + page_size, head_dim, dtype): + """Adjust parameters to fit within device shared memory limits. + + Uses XLA's DeviceDescription.shared_memory_per_block_optin() to query + the actual device capability rather than hardcoding values. + """ + try: + device = jax.local_devices()[0] + # Query XLA DeviceDescription for max shared memory per block + # This is exposed from stream_executor::DeviceDescription::shared_memory_per_block_optin() + max_smem = device.shared_memory_per_block_optin + except (AttributeError, IndexError): + # Fallback if XLA doesn't expose shared_memory_per_block_optin (older versions) + # or if no devices are available. Use conservative 48KB (safe for most GPUs). 
+ max_smem = 48 * 1024 + + estimated = self._estimate_shared_memory_bytes( + block_h, pages_per_compute_block, page_size, head_dim, dtype) + + # If within limits, no adjustment needed + if estimated <= max_smem: + return block_h, pages_per_compute_block, page_size + + # Try to reduce parameters to fit + while estimated > max_smem: + if pages_per_compute_block > 2: + pages_per_compute_block = pages_per_compute_block // 2 + elif page_size > 8: + page_size = page_size // 2 + elif block_h > 8: + block_h = block_h // 2 + else: + # Can't reduce further, will need to skip + return None, None, None + + estimated = self._estimate_shared_memory_bytes( + block_h, pages_per_compute_block, page_size, head_dim, dtype) + + return block_h, pages_per_compute_block, page_size + @jtu.sample_product( dtype=(jnp.float16,), page_size=(8, 16, 32), @@ -201,6 +259,17 @@ def test_quantized_paged_attention( if (quant_dtype == jnp.float8_e4m3fn and not jtu.is_cuda_compute_capability_at_least("8.9")): self.skipTest("Skipping since float8_e4m3fn is not supported on < sm89") + + # Check and adjust parameters if needed to fit device limits for ROCm + if jtu.is_device_rocm(): + adjusted = self._adjust_params_for_shared_memory( + block_h, pages_per_compute_block, page_size, head_dim, dtype) + + if adjusted == (None, None, None): + self.skipTest("Cannot adjust parameters to fit ROCm device shared memory limits") + + block_h, pages_per_compute_block, page_size = adjusted + max_kv_len = 2048 seq_lens = np.asarray([3, 256, 513, 1023, 2048], dtype=jnp.int32) q, k_pages, v_pages, block_tables = _generate_qkv( @@ -218,7 +287,7 @@ def test_quantized_paged_attention( k_, k_scales = (_quantize(k_pages, quant_dtype) if quantize_k else (k_pages, None)) - v_, v_scales = (_quantize(k_pages, quant_dtype) + v_, v_scales = (_quantize(v_pages, quant_dtype) if quantize_v else (v_pages, None)) o = paged_attention.paged_attention( From 01746ea75a4dc464794f102c3574aa698b14b36e Mon Sep 17 00:00:00 2001 From: Aleksei 
<208770786+Arech8@users.noreply.github.com> Date: Wed, 26 Nov 2025 17:19:56 +0200 Subject: [PATCH 10/44] Address LLVM assertion failure due to a multithreaded use. Update .gitignore (#563) When jaxlib was built in debug mode, an assertion in LLVM code that lazy-loads VHLO dialect could fire, since the code path could execute in a multi-threaded environment, and LLVM dialect repositories aren't thread safe to modify. This patch applies the same changes that upstream makes to fix this: https://github.com/jax-ml/jax/commit/48c876227a67840d7bcd44d80d16edbc0e910335 (this includes disabling a call to `jax_mlir_ext.enter_multi_threaded_execution(context)` in `mlir.py`). Presumably, the whole functionality related to `enter_multi_threaded_execution()` multithreaded checks isn't ready yet, and it was prematurely rolled into the production code. Manual testing --- .gitignore | 4 ++++ jax/_src/interpreters/mlir.py | 5 ++--- jaxlib/mlir/_mlir_libs/jax_mlir_ext.cc | 1 - 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index d30c019b31e8..e55a9a5de101 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,7 @@ jax.iml /include/ /lib/ /share/ + +/compile_commands.json +/strace.txt +/external diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index eed5f24e6656..9d9493617b4b 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -571,9 +571,8 @@ def make_ir_context() -> ir.Context: # multi threaded execution aborts the process if we try to register a new # dialect after this point. The dialect registry in a context is not thread # safe, and a fatal error is much better than a data race. - # jax_mlir_ext.enter_multi_threaded_execution(context) - # TODO(phawkins): clean up users who add their own dialects to JAX's contexts - # and enable this.
+ # if jaxlib_version >= (0, 8): + # jax_mlir_ext.enter_multi_threaded_execution(context) return context diff --git a/jaxlib/mlir/_mlir_libs/jax_mlir_ext.cc b/jaxlib/mlir/_mlir_libs/jax_mlir_ext.cc index f1bf60bb0517..65ee1673ca7c 100644 --- a/jaxlib/mlir/_mlir_libs/jax_mlir_ext.cc +++ b/jaxlib/mlir/_mlir_libs/jax_mlir_ext.cc @@ -204,7 +204,6 @@ NB_MODULE(_jax_mlir_ext, m) { unwrap(registry)->insert(); unwrap(registry)->insert(); unwrap(registry)->insert(); - // For Mosaic GPU REGISTER_DIALECT(cf); REGISTER_DIALECT(gpu); From f555563127b689152a24613cf0ac89d4851c046e Mon Sep 17 00:00:00 2001 From: Aleksei <208770786+Arech8@users.noreply.github.com> Date: Wed, 26 Nov 2025 18:35:19 +0200 Subject: [PATCH 11/44] Add skip of test_is_finite() on Cuda (#565) (forgot this skip in the previous PR) --- tests/pallas/ops_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 58685ee05e18..d56700fd81dd 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1003,6 +1003,10 @@ def test_is_finite(self, dtype): # The original test worked only on fp32@TPU, have no way to test CUDA self.skipTest("Not tested on CUDA, todo for the respective team") + if jtu.test_device_matches(["cuda"]): + self.skipTest("Not tested on CUDA") # set this b/c this how the test was + # originally configured. Have no way to test cuda. 
+ size = len(self.IS_FINITE_TEST_VALUES) @functools.partial( From 8cf787ad98a87f5708c02a4e7290ea34d611631f Mon Sep 17 00:00:00 2001 From: AratiGanesh Date: Mon, 15 Dec 2025 08:00:42 -0800 Subject: [PATCH 12/44] Add rocm test requirements file (#570) --- build/rocm-test-requirements.txt | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 build/rocm-test-requirements.txt diff --git a/build/rocm-test-requirements.txt b/build/rocm-test-requirements.txt new file mode 100644 index 000000000000..399237175957 --- /dev/null +++ b/build/rocm-test-requirements.txt @@ -0,0 +1,24 @@ +absl-py +build +cloudpickle +colorama>=0.4.4 +filelock +flatbuffers +hypothesis +mpmath>=1.3 +pillow>=10.4.0 +# TODO(kanglan): Remove once psutil from portpicker supports python 3.13t +portpicker; python_version<"3.13" +pytest-xdist +pytest-json-report +pytest-html +pytest-csv +pytest-rerunfailures +pytest-html-merger +pytest-reportlog +wheel +rich +setuptools +matplotlib +opt-einsum +auditwheel From 17e6022f98ff7ff23b38b0df17a9c70e43fe8ba1 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Mon, 15 Dec 2025 11:03:51 -0600 Subject: [PATCH 13/44] Let the unit tests use build.py for setting up Bazel commands for unit tests (#582) --- build/build.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/build/build.py b/build/build.py index 79de19508687..478ae0586c63 100755 --- a/build/build.py +++ b/build/build.py @@ -638,6 +638,10 @@ async def main(): ) if "rocm" in args.wheels: + if not args.configure_only: + print("ERROR: This repo is not used for building the ROCm JAX plugins. 
Please use the new plugin repo: https://github.com/ROCm/rocm-jax") + exit(1) + wheel_build_command_base.append("--config=rocm_base") wheel_build_command_base.append("--config=rocm") if clang_local: From b600136d29f008ed2d8cf0df521148f4fe6cd464 Mon Sep 17 00:00:00 2001 From: Gulsum Gudukbay Akbulut Date: Tue, 13 Jan 2026 10:29:56 -0600 Subject: [PATCH 14/44] adding abort logic to rocm/jax (#590) --- conftest.py | 177 +++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 176 insertions(+), 1 deletion(-) diff --git a/conftest.py b/conftest.py index fa0e6de94346..72b4b598891c 100644 --- a/conftest.py +++ b/conftest.py @@ -15,7 +15,10 @@ import os import pytest - +import json +import threading +import shutil +from datetime import datetime @pytest.fixture(autouse=True) def add_imports(doctest_namespace): @@ -72,3 +75,175 @@ def pytest_collection() -> None: os.environ.setdefault( "CUDA_VISIBLE_DEVICES", str(xdist_worker_number % num_cuda_devices) ) + +class ThreadSafeTestLogger: + """Thread-safe logging for parallel test execution and abort detection""" + def __init__(self): + self.locks = {} + self.global_lock = threading.Lock() + self.base_dir = os.path.abspath("./logs") + + # Create logs directory (archiving is handled by test runner scripts) + try: + os.makedirs(self.base_dir, exist_ok=True) + print(f"[TestLogger] Initialized log directory: {self.base_dir}") + except Exception as e: + print(f"[TestLogger] ERROR: Failed to create log directory {self.base_dir}: {e}") + # Fallback to temp directory if logs dir creation fails + import tempfile + self.base_dir = os.path.join(tempfile.gettempdir(), "jax_test_logs") + os.makedirs(self.base_dir, exist_ok=True) + print(f"[TestLogger] Using fallback directory: {self.base_dir}") + + def get_file_lock(self, test_file): + """Get or create a lock for a specific test file""" + with self.global_lock: + if test_file not in self.locks: + self.locks[test_file] = threading.Lock() + return self.locks[test_file] + + def 
get_test_file_name(self, session): + """Extract the test file name from the session""" + # Try to get from session config args + if hasattr(session, "config") and hasattr(session.config, "args"): + for arg in session.config.args: + # Handle full nodeid like "jax/tests/foo_test.py::TestClass::test_method" + if "tests/" in arg: + # Split on :: to get just the file path + file_path = arg.split("::")[0] + if file_path.endswith(".py"): + return os.path.basename(file_path).replace(".py", "") + + # Try to get from invocation params + if hasattr(session, "config") and hasattr(session.config, "invocation_params"): + invocation_dir = getattr(session.config.invocation_params, "dir", None) + if invocation_dir: + dir_name = os.path.basename(str(invocation_dir)) + if dir_name: + print(f"[TestLogger] Using invocation directory as test name: {dir_name}") + return dir_name + + # Last resort: try to get from session items + if hasattr(session, "items") and session.items: + first_item = session.items[0] + if hasattr(first_item, "fspath"): + fspath = str(first_item.fspath) + if ".py" in fspath: + return os.path.basename(fspath).replace(".py", "") + + print(f"[TestLogger] WARNING: Could not determine test file name, using 'unknown_test'") + print(f"[TestLogger] Session config args: {getattr(session.config, 'args', 'N/A')}") + return "unknown_test" + + def log_running_test(self, test_file, test_name, nodeid, start_time): + """Log the currently running test for abort detection""" + lock = self.get_file_lock(test_file) + with lock: + log_data = { + "test_file": test_file, + "test_name": test_name, + "nodeid": nodeid, + "start_time": start_time, + "status": "running", + "pid": os.getpid(), + "gpu_id": os.environ.get("HIP_VISIBLE_DEVICES", "unknown"), + } + + log_file = f"{self.base_dir}/{test_file}_last_running.json" + try: + # Ensure directory still exists (might have been deleted) + os.makedirs(self.base_dir, exist_ok=True) + with open(log_file, "w") as f: + json.dump(log_data, f, 
indent=2) + except Exception as e: + print(f"[TestLogger] ERROR: Failed to write running test log to {log_file}: {e}") + print(f"[TestLogger] Current working directory: {os.getcwd()}") + print(f"[TestLogger] Base directory: {self.base_dir}") + print(f"[TestLogger] Base directory exists: {os.path.exists(self.base_dir)}") + raise + + def clear_running_test(self, test_file): + """Clear the running test log when test completes successfully""" + lock = self.get_file_lock(test_file) + with lock: + log_file = f"{self.base_dir}/{test_file}_last_running.json" + if os.path.exists(log_file): + os.remove(log_file) + + +# Global logger instance +test_logger = ThreadSafeTestLogger() + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_protocol(item, nextitem): + """Hook that wraps around each test to track running tests for crash detection. + + This creates a "last_running" file before each test starts and deletes it + when the test completes successfully. If the test crashes, the file remains + and can be detected by the test runner. 
+ """ + test_file = test_logger.get_test_file_name(item.session) + test_name = item.name + nodeid = item.nodeid + start_time = datetime.now().isoformat() + + # Log that this test is starting + try: + test_logger.log_running_test(test_file, test_name, nodeid, start_time) + except Exception as e: + print(f"[TestLogger] WARNING: Failed to log running test: {e}") + # Continue anyway - not critical for test execution + + test_completed = False + try: + outcome = yield + # Test completed (successfully or with normal failure) + test_completed = True + + # Clear the crash detection file + try: + test_logger.clear_running_test(test_file) + except Exception as e: + print(f"[TestLogger] WARNING: Failed to clear running test log: {e}") + + except Exception as e: + # Test raised exception (might be crash, might be normal exception) + print(f"[TestLogger] Test {test_name} exception: {e}") + if not test_completed: + # Don't clear the file - this might be a crash + print(f"[TestLogger] Leaving crash file for detection") + raise + + +@pytest.hookimpl(tryfirst=True) +def pytest_sessionstart(session): + """Called after the Session object has been created""" + gpu = os.environ.get('HIP_VISIBLE_DEVICES', '?') + print(f"Test session starting on GPU {gpu}") + + +@pytest.hookimpl(trylast=True) +def pytest_sessionfinish(session, exitstatus): + """Called after test run finished. + + If a crash file still exists, it means a test crashed and the runner + will detect it. We just report it here for visibility. 
+ """ + test_file = test_logger.get_test_file_name(session) + log_file = f"{test_logger.base_dir}/{test_file}_last_running.json" + + if os.path.exists(log_file): + try: + with open(log_file, "r") as f: + abort_data = json.load(f) + print( + f"\n[CRASH DETECTED] {abort_data.get('nodeid', abort_data.get('test_name', 'unknown'))} " + f"(GPU: {abort_data.get('gpu_id', '?')}, PID: {abort_data.get('pid', '?')})" + ) + print(f"[CRASH DETECTED] Crash file will be processed by test runner") + except Exception as e: + print(f"[TestLogger] WARNING: Crash file exists but unreadable: {e}") + else: + # Normal completion - no crash + pass From 02399d0bd58574732431ab2add16c351247bdeb5 Mon Sep 17 00:00:00 2001 From: Pham Binh Date: Wed, 14 Jan 2026 19:57:26 +0200 Subject: [PATCH 15/44] Skip is_finite tests on ROCm (not in Triton lowering for jax 0.8.0) (#597) --- tests/pallas/ops_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index d56700fd81dd..c6384725fe9e 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1007,6 +1007,9 @@ def test_is_finite(self, dtype): self.skipTest("Not tested on CUDA") # set this b/c this how the test was # originally configured. Have no way to test cuda. 
+ if jtu.is_device_rocm(): + self.skipTest("is_finite not in Triton lowering for jax 0.8.0") + size = len(self.IS_FINITE_TEST_VALUES) @functools.partial( @@ -1057,6 +1060,9 @@ def test_is_finite_scalar(self, dtype): # The original test worked only on fp32@TPU, have no way to test CUDA self.skipTest("Not tested on CUDA, todo for the respective team") + if jtu.is_device_rocm(): + self.skipTest("is_finite not in Triton lowering for jax 0.8.0") + size = len(self.IS_FINITE_TEST_VALUES) @functools.partial( From 0959b0fe3d2da7662ea6f9d9a1982cff66542716 Mon Sep 17 00:00:00 2001 From: Pham Binh Date: Wed, 14 Jan 2026 19:57:53 +0200 Subject: [PATCH 16/44] Fix shared memory limit check for ROCm in test_dot (#596) --- tests/pallas/ops_test.py | 50 +++++++++++++++++++++++++++++++++++----- 1 file changed, 44 insertions(+), 6 deletions(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index c6384725fe9e..c47f7a4fcf4d 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -74,6 +74,34 @@ def is_power_of_two(n: int) -> bool: return (n > 0) and (n & (n - 1) == 0) +def get_rocm_shared_memory_limit() -> int: + """Get the shared memory (LDS) limit in bytes for ROCm devices. + + Queries rocminfo to get the GROUP segment size dynamically. + Returns 64KB as default if rocminfo fails (MI100/MI200/MI300 all have 64KB LDS). 
+ """ + try: + result = subprocess.run( + ['rocminfo'], capture_output=True, text=True, timeout=10 + ) + if result.returncode != 0: + return 64 * 1024 # Default if rocminfo fails + lines = result.stdout.split('\n') + for i, line in enumerate(lines): + if 'Segment:' in line and 'GROUP' in line: + if i + 1 < len(lines): + size_line = lines[i + 1] + # Match "Size: () KB" with case-insensitive KB check + match = re.search(r'Size:\s+(\d+)\s*\([^)]+\)\s*KB', size_line, re.IGNORECASE) + if match: + size_kb = int(match.group(1)) + return size_kb * 1024 # Convert KB to bytes + except Exception: + pass + # Default for AMD GPUs (MI100/MI200/MI300 all have 64KB LDS) + return 64 * 1024 + + def smem_on_tpu(): if jtu.test_device_matches(["tpu"]): return pltpu.SMEM @@ -2095,12 +2123,22 @@ def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y): if jtu.test_device_matches(["gpu"]): if dtype == jnp.bfloat16: self.skipTest("bfloat16 type are not supported on GPU") - if ( - math.prod(lhs_shape) + math.prod(rhs_shape) + math.prod(out_shape) - > (256 * 256) * 2 - ): - self.skipTest("Shared memory size limit exceeded") - if (jax.local_devices()[0].shared_memory_per_block_optin == 99 * 1024 and + # Check shared memory limit: Triton loads lhs + rhs into shared memory + if jtu.is_device_rocm(): + # ROCm: use correct formula with dynamic limit from rocminfo + dtype_size = jnp.dtype(dtype).itemsize + shared_mem_bytes = (math.prod(lhs_shape) + math.prod(rhs_shape)) * dtype_size + shared_mem_limit = get_rocm_shared_memory_limit() + if shared_mem_bytes > shared_mem_limit: + self.skipTest("Shared memory size limit exceeded") + else: + # NVIDIA: keep original check + if ( + math.prod(lhs_shape) + math.prod(rhs_shape) + math.prod(out_shape) + > (256 * 256) * 2 + ): + self.skipTest("Shared memory size limit exceeded") + if (jax.local_devices()[0].device_kind == "NVIDIA L4" and dtype == jnp.float32 and lhs_and_rhs_shape in [ ((128, 16), (128, 256)), From 
b43ca1801b9ad9cb8006de23a9aed76df47828ac Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Wed, 14 Jan 2026 12:14:02 -0600 Subject: [PATCH 17/44] Fix Numpy signatures test (#598) Co-authored-by: Daniel Suo Co-authored-by: Jake VanderPlas --- tests/lax_numpy_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index f69867f56271..68b41b1a0c95 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6397,6 +6397,10 @@ def testWrappedSignaturesMatch(self): 'stack': ['casting'], 'tri': ['like'], 'unravel_index': ['order'], +<<<<<<< HEAD +======= + 'var': ['mean'], +>>>>>>> a3f8af53f (Fix Numpy signatures test (#598)) 'vstack': ['casting'], 'zeros': ['order', 'like'], 'zeros_like': ['subok', 'order'] From cdb5bcb5b508039d027dbf168069ceb0a37b26b8 Mon Sep 17 00:00:00 2001 From: Ruturaj4 Date: Sun, 18 Jan 2026 10:18:54 -0600 Subject: [PATCH 18/44] fix merge arts --- tests/lax_numpy_test.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 68b41b1a0c95..f0671fdc0d80 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6397,10 +6397,7 @@ def testWrappedSignaturesMatch(self): 'stack': ['casting'], 'tri': ['like'], 'unravel_index': ['order'], -<<<<<<< HEAD -======= 'var': ['mean'], ->>>>>>> a3f8af53f (Fix Numpy signatures test (#598)) 'vstack': ['casting'], 'zeros': ['order', 'like'], 'zeros_like': ['subok', 'order'] From de1ef417570d1a8d0afb24c13974e1c496db4fb4 Mon Sep 17 00:00:00 2001 From: Gulsum Gudukbay Akbulut Date: Thu, 22 Jan 2026 15:44:41 -0600 Subject: [PATCH 19/44] Enable RngShardingTests (#644) --- tests/array_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/array_test.py b/tests/array_test.py index 61c30b9ed065..970b93050de4 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1591,7 +1591,7 @@ class RngShardingTest(jtu.JaxTestCase): # tests that the PRNGs are 
automatically sharded as expected @parameterized.named_parameters(("3", 3), ("4", 4), ("5", 5)) - @jtu.skip_on_devices("gpu") + @jtu.skip_on_devices("cuda") def test_random_bits_is_pure_map_1d(self, num_devices): @jax.jit def f(x): @@ -1625,7 +1625,7 @@ def f(x): "mesh_shape": mesh_shape, "pspec": pspec} for mesh_shape in [(3, 2), (4, 2), (2, 3)] for pspec in [P('x', None), P(None, 'y'), P('x', 'y')]) - @jtu.skip_on_devices("gpu") + @jtu.skip_on_devices("cuda") def test_random_bits_is_pure_map_2d(self, mesh_shape, pspec): @jax.jit def f(x): From d8179cdfeb875de7fdbc5caa6f2ccd6221ce9a11 Mon Sep 17 00:00:00 2001 From: Marco Minutoli Date: Thu, 12 Feb 2026 15:18:40 -0800 Subject: [PATCH 20/44] Enable test_variadic_reduce_window on ROCm (#647) --- tests/lax_vmap_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 57fdfc4dda88..de260ec4ed93 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -759,6 +759,8 @@ def testSort(self, shape, dimension, arity, bdims, is_stable): # TODO Collapse # TODO Scatter + # b/183233858: variadic reduce-window not implemented on XLA:CUDA + @jtu.skip_on_devices("cuda") def test_variadic_reduce_window(self): # https://github.com/jax-ml/jax/discussions/9818 and # https://github.com/jax-ml/jax/issues/9837 From c5016efc81d95fd6397a35958b33cf2b041c0e75 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 13:38:16 -0600 Subject: [PATCH 21/44] Skip sparse tests on ROCm due to hipSPARSE issue (#652) --- tests/sparse_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 58adf3a42cf1..b5bc7b511ec9 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,6 +137,7 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def 
test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): # hipSPARSE segfault observed as of ROCm 7.2. @@ -219,6 +220,7 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): # hipSPARSE segfault observed as of ROCm 7.2. @@ -592,6 +594,7 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): # hipSPARSE segfault observed as of ROCm 7.2. @@ -1047,6 +1050,7 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): # hipSPARSE segfault observed as of ROCm 7.2. From 4e6626ea815455a614a3870ead2ff65aa90d14e4 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 15:35:23 -0600 Subject: [PATCH 22/44] Update sparse test skip messages in v0.8.2 (#653) --- tests/sparse_test.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index b5bc7b511ec9..30240ff69c50 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,11 +137,8 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): - # hipSPARSE segfault observed as of ROCm 7.2. - # TODO(ROCm): Re-enable once hipSPARSE issue is fixed. 
self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") csr_matmul = sparse_csr._csr_matvec if len(bshape) == 1 else sparse_csr._csr_matmat tol = {np.float32: 2E-5, np.float64: 1E-12, np.complex64: 1E-5, @@ -220,11 +217,8 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): - # hipSPARSE segfault observed as of ROCm 7.2. - # TODO(ROCm): Re-enable once hipSPARSE issue is fixed. self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") op = lambda M: M.T if transpose else M @@ -594,11 +588,8 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): - # hipSPARSE segfault observed as of ROCm 7.2. - # TODO(ROCm): Re-enable once hipSPARSE issue is fixed. self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") tol = {np.float32: 2E-5, np.float64: 2E-14} @@ -1050,11 +1041,8 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): - # hipSPARSE segfault observed as of ROCm 7.2. - # TODO(ROCm): Re-enable once hipSPARSE issue is fixed. 
self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") rng = sptu.rand_sparse(self.rng(), post=jnp.array) rng_b = jtu.rand_default(self.rng()) From 694e861d8bd8f505ee3fac761910421129a17e90 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 13:38:16 -0600 Subject: [PATCH 23/44] Skip sparse tests on ROCm due to hipSPARSE issue (#652) --- tests/sparse_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 30240ff69c50..fe6a035f75c8 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,6 +137,7 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -217,6 +218,7 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") @@ -588,6 +590,7 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1041,6 +1044,7 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if 
jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") From 12e07fb9c442db81a853141303053d393892aa4f Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 15:35:23 -0600 Subject: [PATCH 24/44] Update sparse test skip messages in v0.8.2 (#653) --- tests/sparse_test.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index fe6a035f75c8..30240ff69c50 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,7 +137,6 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -218,7 +217,6 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") @@ -590,7 +588,6 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1044,7 +1041,6 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm 
due to hipSPARSE issue") From 76e576f609be03887ca169a5a43fe8f83ea217db Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 13:38:16 -0600 Subject: [PATCH 25/44] Skip sparse tests on ROCm due to hipSPARSE issue (#652) --- tests/sparse_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 30240ff69c50..fe6a035f75c8 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,6 +137,7 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -217,6 +218,7 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") @@ -588,6 +590,7 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1041,6 +1044,7 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") From 
237e5ad892e5685df094b77c2d264586fb5f1ce4 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 15:35:23 -0600 Subject: [PATCH 26/44] Update sparse test skip messages in v0.8.2 (#653) --- tests/sparse_test.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index fe6a035f75c8..30240ff69c50 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,7 +137,6 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -218,7 +217,6 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") @@ -590,7 +588,6 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1044,7 +1041,6 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") From da3a3cce1e0b9024ff2a5f1372f53e52a586caae Mon Sep 17 00:00:00 2001 From: 
AratiGanesh Date: Wed, 28 Jan 2026 14:00:42 -0800 Subject: [PATCH 27/44] Enable testMultivariateNormalSingularCovariance on ROCm (#666) --- tests/random_lax_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index 1842da7c1b25..44314e9034fe 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -963,7 +963,7 @@ def testMultivariateNormalCovariance(self): check_dtypes=False) @jtu.sample_product(method=['cholesky', 'eigh', 'svd']) - @jtu.skip_on_devices('cuda', 'tpu') # Some NaNs on accelerators. + @jtu.skip_on_devices('cuda', 'tpu') # Some NaNs on accelerators. ROCm supported def testMultivariateNormalSingularCovariance(self, method): # Singular covariance matrix https://github.com/jax-ml/jax/discussions/13293 mu = jnp.zeros((2,)) From 06d459eb387b1065704f1cf7636ee71f5ac00461 Mon Sep 17 00:00:00 2001 From: AratiGanesh Date: Wed, 28 Jan 2026 14:02:39 -0800 Subject: [PATCH 28/44] Skip test_tridiagonal_solve on ROCm due to hipSPARSE numerical errors (#668) --- tests/linalg_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 63c8a28bb76b..754cf5b09e64 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -2371,6 +2371,7 @@ def testSelect(self, dtype): @jtu.sample_product(shape=[(3,), (3, 4), (3, 4, 5)], dtype=float_types + complex_types) + @jtu.skip_on_devices("rocm") # Numerical errors on ROCm def test_tridiagonal_solve(self, shape, dtype): if dtype not in float_types and jtu.test_device_matches(["gpu"]): self.skipTest("Data type not supported on GPU") From c30a449424ea16bc58c4352c070779163604a502 Mon Sep 17 00:00:00 2001 From: Gulsum Gudukbay Akbulut Date: Wed, 28 Jan 2026 16:38:17 -0600 Subject: [PATCH 29/44] Update Skip Reason Outputs (#663) --- jax/_src/test_util.py | 20 ++++++++++++++++---- tests/pallas/gpu_pallas_distributed_test.py | 14 ++++++++------ 2 files changed, 24 insertions(+), 10 deletions(-) diff 
--git a/jax/_src/test_util.py b/jax/_src/test_util.py index 5f24adcd694c..f6708908da3a 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -624,7 +624,14 @@ def skip_on_devices(*disabled_devices, skip_reason=None): skip_reason: Optional custom skip message when test is skipped. """ if skip_reason is None: - skip_reason = "Skipped on devices with tags: " + ", ".join(disabled_devices) + skip_messages = { + ("gpu",): "Skipped on all GPUs.", + ("cpu",): "Skipped on CPU.", + ("tpu",): "Skipped on TPU.", + ("cuda",): "Skipped on CUDA GPUs.", + ("rocm",): "Skipped on ROCm GPUs.", + } + skip_reason = skip_messages.get(disabled_devices) return _device_filter(lambda: not test_device_matches(disabled_devices), skip_reason) def run_on_devices(*enabled_devices, skip_reason=None): @@ -635,9 +642,14 @@ def run_on_devices(*enabled_devices, skip_reason=None): skip_reason: Optional custom skip message when test is skipped. """ if skip_reason is None: - skip_reason = ( - "Skipped unless running on devices with tags: " + ", ".join(enabled_devices) - ) + device_specific_skip_reasons = { + ("cpu",): "Skipped: CPU-only test.", + ("tpu",): "Skipped: TPU-only test.", + ("gpu",): "Skipped: GPU-only test.", + ("rocm",): "Skipped: ROCm-only test.", + ("cuda",): "Skipped: CUDA-only test.", + } + skip_reason = device_specific_skip_reasons.get(enabled_devices) return _device_filter(lambda: test_device_matches(enabled_devices), skip_reason) def device_supports_buffer_donation(): diff --git a/tests/pallas/gpu_pallas_distributed_test.py b/tests/pallas/gpu_pallas_distributed_test.py index 610965cb8429..dce9fc329187 100644 --- a/tests/pallas/gpu_pallas_distributed_test.py +++ b/tests/pallas/gpu_pallas_distributed_test.py @@ -51,20 +51,22 @@ def setUp(self): if jtu.test_device_matches(["rocm"]): self.skipTest("Mosaic not supported on ROCm currently.") + # Check mosaic support first (before GPU capability check) + if not mgpu.supports_cross_device_collectives(): + if 
jtu.test_device_matches(["rocm"]): + self.skipTest("Mosaic not supported on ROCm currently.") + else: + self.skipTest("NVSHMEM library unavailable.") if (not jtu.test_device_matches(["cuda"]) or not jtu.is_cuda_compute_capability_at_least("9.0")): self.skipTest("Only works on GPU with capability >= sm90") - if not mgpu.supports_cross_device_collectives(): - self.skipTest( - "Skip test since cross-device collectives are not supported" - " (either NVSHMEM is not available in multi-process mode, or mixed" - " mode is used).") + if jax.process_count() == 1: + self.skipTest("Test requires multiple processes.") if os.environ.get("XLA_PYTHON_CLIENT_ALLOCATOR", "") == "platform": self.skipTest("NVSHMEM doesn't work with the platform allocator.") super().setUp() - class PallasCallRemoteDMATest(TestCase): def test_remote_dma_basic(self): From 58ce4e11ac008ca0b9fd089987360a7d762560a1 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 13:38:16 -0600 Subject: [PATCH 30/44] Skip sparse tests on ROCm due to hipSPARSE issue (#652) --- tests/sparse_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 30240ff69c50..fe6a035f75c8 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,6 +137,7 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -217,6 +218,7 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to 
hipSPARSE issue") @@ -588,6 +590,7 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1041,6 +1044,7 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") From e8307d2ccbcfacd01f96d4f514d334c5d72ec1d8 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 15:35:23 -0600 Subject: [PATCH 31/44] Update sparse test skip messages in v0.8.2 (#653) --- tests/sparse_test.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index fe6a035f75c8..30240ff69c50 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,7 +137,6 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -218,7 +217,6 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") @@ -590,7 +588,6 @@ def test_coo_spmm(self, shape, 
dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1044,7 +1041,6 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") From 64ee74ec4c122ffbeeaa51f15fe863206175529e Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Thu, 29 Jan 2026 13:12:07 -0600 Subject: [PATCH 32/44] Skip testCudaArrayInterfaceOnNonCudaFails on ROCm platform (#677) --- tests/array_interoperability_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index d033a490f547..b14e9171f952 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -312,7 +312,7 @@ def testCudaArrayInterfaceOnNonCudaFails(self): self.assertFalse(hasattr(x, "__cuda_array_interface__")) with self.assertRaisesRegex( AttributeError, - "__cuda_array_interface__ is only defined for .*GPU buffers.", + "__cuda_array_interface__ is only defined for GPU buffers.", ): _ = x.__cuda_array_interface__ From f1441325f14a7f503ded94da97523973dc5a1b25 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 13:38:16 -0600 Subject: [PATCH 33/44] Skip sparse tests on ROCm due to hipSPARSE issue (#652) --- tests/sparse_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 30240ff69c50..fe6a035f75c8 100644 --- a/tests/sparse_test.py +++ 
b/tests/sparse_test.py @@ -137,6 +137,7 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -217,6 +218,7 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") @@ -588,6 +590,7 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1041,6 +1044,7 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") From 9dd1698ee3180549f3f6de94f010725971d8f24c Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 15:35:23 -0600 Subject: [PATCH 34/44] Update sparse test skip messages in v0.8.2 (#653) --- tests/sparse_test.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index fe6a035f75c8..30240ff69c50 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,7 +137,6 @@ def 
test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -218,7 +217,6 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") @@ -590,7 +588,6 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1044,7 +1041,6 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") From fd1195e3498615101089e26232f57f0456cf2ac7 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 13:38:16 -0600 Subject: [PATCH 35/44] Skip sparse tests on ROCm due to hipSPARSE issue (#652) --- tests/sparse_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 30240ff69c50..fe6a035f75c8 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,6 +137,7 @@ def test_csr_fromdense_ad(self, shape, dtype): 
dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -217,6 +218,7 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") @@ -588,6 +590,7 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1041,6 +1044,7 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") + @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") From 4af53277cd7861b9501c06e4c0525201aedcaeb1 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 23 Jan 2026 15:35:23 -0600 Subject: [PATCH 36/44] Update sparse test skip messages in v0.8.2 (#653) --- tests/sparse_test.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index fe6a035f75c8..30240ff69c50 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -137,7 +137,6 @@ def test_csr_fromdense_ad(self, shape, dtype): dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) 
@jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") @@ -218,7 +217,6 @@ def test_csr_fromdense(self, shape, dtype): dtype=all_dtypes, transpose=[True, False], ) - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") @@ -590,7 +588,6 @@ def test_coo_spmm(self, shape, dtype, transpose): transpose=[True, False], ) @jtu.run_on_devices("gpu") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") @@ -1044,7 +1041,6 @@ def test_transpose(self, shape, dtype, Obj): ) for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])) @jax.default_matmul_precision("float32") - @jtu.skip_on_devices("rocm") # skipping on ROCm due to known issue in hipSPARSE def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") From bfe02082b45aebb6abeccb41ca6e696406abf9ae Mon Sep 17 00:00:00 2001 From: AratiGanesh Date: Thu, 5 Feb 2026 10:50:39 -0800 Subject: [PATCH 37/44] Add ROCm encoding for test_struct_encoding_determinism (#683) --- tests/experimental_rnn_test.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/experimental_rnn_test.py b/tests/experimental_rnn_test.py index 4bb611dcd842..61e79c00a09e 100644 --- a/tests/experimental_rnn_test.py +++ b/tests/experimental_rnn_test.py @@ -179,7 +179,7 @@ def f(weights, x, h_0, c_0): y_padded = y_ref[i, seq_lengths[i]:] 
np.testing.assert_allclose(y_padded, jnp.zeros_like(y_padded)) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_struct_encoding_determinism(self): def f(k1, k2, k3, k4): batch_size = 1 @@ -213,8 +213,15 @@ def f(k1, k2, k3, k4): k = jax.random.split(jax.random.PRNGKey(1), 4) stablehlo = jax.jit(f).lower(*k).as_text("stablehlo") - self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"', - stablehlo) + # Platform-specific binary encodings for RnnDescriptor + cuda_encoding = '"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"' + rocm_encoding = '"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\008\\00\\00\\00\\00\\00\\00\\00\\1C\\00\\00\\00\\00\\00\\00\\00"' + + # Check that one of the expected encodings is present + if jtu.test_device_matches(["cuda"]): + self.assertIn(cuda_encoding, stablehlo) + elif jtu.test_device_matches(["rocm"]): + self.assertIn(rocm_encoding, stablehlo) # Note: Other LSTM tests that use `bidirectional=True` on ROCm are skipped # because of current numerical issues (as of ROCm 7.1.1). 
However, this From 44b8a6c30e15c0fdae7ded4d7563776af4ac59b9 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 6 Feb 2026 12:48:48 -0600 Subject: [PATCH 38/44] Remove 'mean' from unsupported params for jnp.var (#689) --- tests/lax_numpy_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index f0671fdc0d80..f69867f56271 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6397,7 +6397,6 @@ def testWrappedSignaturesMatch(self): 'stack': ['casting'], 'tri': ['like'], 'unravel_index': ['order'], - 'var': ['mean'], 'vstack': ['casting'], 'zeros': ['order', 'like'], 'zeros_like': ['subok', 'order'] From 7bd4a132eee29e2ebfe97f7cb046b3b64069e119 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Fri, 6 Feb 2026 16:49:53 -0600 Subject: [PATCH 39/44] Implement approx_tanh for ROCm using OCML tanh function (#691) --- jax/_src/pallas/triton/primitives.py | 92 ++++++++++++++++++++++++++++ tests/pallas/ops_test.py | 41 +++++++++++++ 2 files changed, 133 insertions(+) diff --git a/jax/_src/pallas/triton/primitives.py b/jax/_src/pallas/triton/primitives.py index 8e763c3d8e6a..25423ebba471 100644 --- a/jax/_src/pallas/triton/primitives.py +++ b/jax/_src/pallas/triton/primitives.py @@ -48,6 +48,11 @@ def approx_tanh(x: jax.Array) -> jax.Array: elif x.dtype == jnp.float32: asm = "tanh.approx.f32 $0, $1;" constraint = "f" + elif x.dtype == jnp.float64: + # f64 tanh.approx is only supported on ROCm (uses __ocml_tanh_f64) + # CUDA does not have a PTX instruction for f64 approximate tanh + asm = "tanh.approx.f64 $0, $1;" + constraint = "d" else: raise TypeError(f"approx_tanh does not accept {x.dtype} arrays") @@ -119,6 +124,13 @@ def _elementwise_inline_asm_lowering( result_shape_dtypes, ): del result_shape_dtypes # Unused. + + # For ROCm, PTX inline assembly is not supported. 
For tanh.approx, we use + # Triton's __triton_hip_fast_tanhf (fast exp-based formula) for f32, and + # OCML's __ocml_tanh_f64 for f64. See: https://github.com/triton-lang/triton/pull/7780 + if ctx.context.platform == "rocm" and "tanh.approx" in asm: + return _approx_tanh_rocm_lowering(ctx, *args) + return tt_dialect.ElementwiseInlineAsmOp( [*map(mlir.aval_to_ir_type, ctx.avals_out)], asm, @@ -129,6 +141,86 @@ def _elementwise_inline_asm_lowering( ).result +def _approx_tanh_rocm_lowering( + ctx: lowering.LoweringRuleContext, + *args, +): + """Lower approx_tanh for ROCm. + + AMD CDNA3 (MI300X/gfx942) does not have a hardware tanh instruction. + + For f32 (and f16/bf16 via casting): We use Triton's __triton_hip_fast_tanhf + which implements a fast exp-based formula: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) + See: https://github.com/triton-lang/triton/pull/7780 + + For f64: We use OCML's __ocml_tanh_f64 (AMD's Open Compute Math Library) + since fast_tanhf only supports f32. + """ + from jax._src.lib.mlir import ir + from jax._src.lib.mlir.dialects import arith as arith_dialect + + [arg] = args + [out_aval] = ctx.avals_out + in_dtype = ctx.avals_in[0].dtype + + # Helper to get IR type for a dtype + def dtype_to_ir_type(dtype): + dtype = jnp.dtype(dtype) + return mlir.dtype_to_ir_type(dtype) + + # f64: use __ocml_tanh_f64 (fast_tanhf only supports f32) + if in_dtype == jnp.float64: + result_type = mlir.aval_to_ir_type(out_aval) + result = tt_dialect.extern_elementwise( + result_type, + list(args), + libname="", + libpath="", + symbol="__ocml_tanh_f64", + pure=True, + ) + return [result] + + # fast_tanhf only supports f32. For f16/bf16, cast to f32, compute, cast back. 
+ needs_cast = in_dtype in (jnp.float16, jnp.bfloat16) + + if needs_cast: + # Cast input to f32 (extend) + f32_type = dtype_to_ir_type(jnp.float32) + if out_aval.shape: + f32_result_type = ir.RankedTensorType.get(out_aval.shape, f32_type) + else: + f32_result_type = f32_type + arg_f32 = arith_dialect.extf(f32_result_type, arg) + + # Call __triton_hip_fast_tanhf (fast exp-based implementation) + tanh_result = tt_dialect.extern_elementwise( + f32_result_type, + [arg_f32], + libname="libdevice", + libpath="", + symbol="__triton_hip_fast_tanhf", + pure=True, + ) + + # Cast result back to original dtype (truncate) + out_type = mlir.aval_to_ir_type(out_aval) + result = arith_dialect.truncf(out_type, tanh_result) + else: + # f32: call __triton_hip_fast_tanhf directly + result_type = mlir.aval_to_ir_type(out_aval) + result = tt_dialect.extern_elementwise( + result_type, + list(args), + libname="libdevice", + libpath="", + symbol="__triton_hip_fast_tanhf", + pure=True, + ) + + return [result] + + def debug_barrier() -> None: """Synchronizes all kernel executions in the grid.""" return debug_barrier_p.bind() diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index c47f7a4fcf4d..02a721b55555 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1910,6 +1910,47 @@ def kernel(o_ref): np.testing.assert_allclose(f(), kernel()) + @parameterized.parameters("float16", "bfloat16", "float32", "float64") + def test_approx_tanh(self, dtype): + self.skip_if_mosaic_gpu() + + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented on TPU") + + if self.INTERPRET: + self.skipTest("approx_tanh is not supported in interpret mode") + + if (dtype == "bfloat16" and + jtu.test_device_matches(["cuda"]) and + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90") + + if dtype == "float64": + if jtu.test_device_matches(["cuda"]): + self.skipTest("f64 approx_tanh is only supported on 
ROCm") + + # Enable x64 for f64 test if not already enabled, restore after test + original_x64 = jax.config.x64_enabled + if dtype == "float64" and not original_x64: + jax.config.update("jax_enable_x64", True) + self.addCleanup(lambda: jax.config.update("jax_enable_x64", False)) + + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), + ) + def kernel(x_ref, o_ref): + o_ref[...] = plgpu_triton.approx_tanh(x_ref[...]) + + x = jnp.asarray([-1, 0.42, 0.24, 1]).astype(dtype) + # We upcast to float32 because NumPy <2.0 does not handle custom dtypes + # properly. See https://github.com/jax-ml/jax/issues/11014. + np.testing.assert_allclose( + kernel(x).astype(jnp.float32), + jnp.tanh(x).astype(jnp.float32), + atol=5e-3, + rtol=5e-3, + ) + @parameterized.parameters( ((2, 4), (8,)), ((2, 4), (8, 1)), From 95ae9faa1db1d514d2100960bbc99240d0acec5f Mon Sep 17 00:00:00 2001 From: AratiGanesh Date: Mon, 9 Feb 2026 05:08:48 -0800 Subject: [PATCH 40/44] Skipping testEighTinyNorm due to hipSolver issues (#697) --- tests/linalg_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 754cf5b09e64..5b38275947ff 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -570,7 +570,12 @@ def testEighZeroDiagonal(self): np.linalg.norm(np.matmul(a, v) - w * v), 2.5 * eps * np.linalg.norm(a) ) + def testEighTinyNorm(self): + if jtu.is_device_rocm(): + # numerical errors seen as of ROCm 7.2 due to hipSolver issue + # TODO: re-enable the test once the hipSolver issue is fixed + self.skipTest("testEighNorm not supported on ROCm due to hipSOLVER issue") rng = jtu.rand_default(self.rng()) a = rng((300, 300), dtype=np.float32) eps = jnp.finfo(a.dtype).eps From e355fcd697d47e97249e102d3e69924c5347b6e9 Mon Sep 17 00:00:00 2001 From: Gulsum Gudukbay Akbulut Date: Fri, 20 Feb 2026 11:17:09 -0600 Subject: [PATCH 41/44] Abort detection CI workflow (#688) --- .github/workflows/pytest_rocm_abort.yml | 162 
++++++++++++ .../wheel_tests_nightly_release_abort.yml | 53 ++++ ci/run_pytest_rocm_abort.sh | 179 ++++++++++++++ conftest.py | 233 +++++------------- 4 files changed, 456 insertions(+), 171 deletions(-) create mode 100644 .github/workflows/pytest_rocm_abort.yml create mode 100644 .github/workflows/wheel_tests_nightly_release_abort.yml create mode 100755 ci/run_pytest_rocm_abort.sh diff --git a/.github/workflows/pytest_rocm_abort.yml b/.github/workflows/pytest_rocm_abort.yml new file mode 100644 index 000000000000..fcfdb06855d5 --- /dev/null +++ b/.github/workflows/pytest_rocm_abort.yml @@ -0,0 +1,162 @@ +# CI - Pytest ROCm (Abort Support) +# +# This workflow runs the ROCm tests with Pytest in ROCm GHCR containers, +# using the ROCm `pytest-abort` retry wrapper to detect/retry aborts/crashes. +# +# It can be triggered manually via workflow_dispatch or called by other workflows +# via workflow_call. +# +# It consists of the following job: +# run-tests: +# - Runs in ROCm container (ghcr.io/rocm/jax-base-ubu24-rocm*:latest) +# - Downloads the JAX and jaxlib wheels from GCS, and ROCm plugins from latest release. +# - Executes the `run_pytest_rocm_abort.sh` script, which installs wheel artifacts and +# runs the ROCm tests with Pytest under `pytest-abort-retry`. +name: CI - Pytest ROCm (Abort Support) + +on: + workflow_dispatch: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: choice + default: "linux-x86-64-4gpu-amd" + options: + - "linux-x86-64-1gpu-amd" + - "linux-x86-64-4gpu-amd" + - "linux-x86-64-8gpu-amd" + python: + description: "Which Python version to use?" + type: choice + default: "3.11" + options: + - "3.11" + - "3.12" + rocm-version: + description: "Which ROCm version to test?" + type: choice + default: "7.2.0" + options: + - "7.2.0" + rocm-tag: + description: "ROCm tag for container image (e.g., rocm720)" + type: string + default: "rocm720" + jaxlib-version: + description: "Which jaxlib version to use? 
(head/pypi_latest)" + type: choice + default: "head" + options: + - "head" + - "pypi_latest" + skip-download-jaxlib-and-plugins-from-gcs: + description: "Whether to skip downloading the jaxlib and plugins from GCS (e.g for testing a jax only release)" + type: choice + default: '0' + options: + - '0' + - '1' + gcs_download_uri: + description: "GCS location prefix from where the artifacts should be downloaded" + type: string + default: 'gs://jax-nightly-artifacts/latest' + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: string + default: 'no' + workflow_call: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: string + default: "linux-x86-64-4gpu-amd" + python: + description: "Which Python version to use?" + type: string + default: "3.11" + rocm-version: + description: "Which ROCm version to test?" + type: string + default: "7.2.0" + rocm-tag: + description: "ROCm tag for container image (e.g., rocm720)" + type: string + default: "rocm720" + jaxlib-version: + description: "Which jaxlib version to use? 
(head/pypi_latest)" + type: string + default: "head" + skip-download-jaxlib-and-plugins-from-gcs: + description: "Whether to skip downloading the jaxlib and plugins from GCS (e.g for testing a jax only release)" + default: '0' + type: string + gcs_download_uri: + description: "GCS location prefix from where the artifacts should be downloaded" + default: 'gs://jax-nightly-artifacts/latest' + type: string + +permissions: {} + +env: + UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" + +jobs: + run-tests: + defaults: + run: + # Set the shell to bash as GitHub actions run with /bin/sh by default + shell: bash + runs-on: ${{ inputs.runner }} + continue-on-error: true + # Run in ROCm GHCR container with GPU access + container: + image: ghcr.io/rocm/jax-base-ubu24.${{ inputs.rocm-tag }}:latest + credentials: + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --shm-size 64G --env-file /etc/podinfo/gha-gpu-isolation-settings + name: "${{ (contains(inputs.runner, '1gpu') && '1gpu') || + (contains(inputs.runner, '4gpu') && '4gpu') || + (contains(inputs.runner, '8gpu') && '8gpu') }}, ROCm ${{ inputs.rocm-version }}, py${{ inputs.python }}" + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" + JAXCI_PYTHON: "python${{ inputs.python }}" + JAXCI_ENABLE_X64: "0" + + steps: + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false + - name: Download JAX ROCm wheels + uses: ./.github/actions/download-jax-rocm-wheels + with: + python: ${{ inputs.python }} + rocm-version: ${{ inputs.rocm-version }} + jaxlib-version: ${{ inputs.jaxlib-version }} + skip-download-jaxlib-and-plugins-from-gcs: ${{ inputs.skip-download-jaxlib-and-plugins-from-gcs }} + gcs_download_uri: ${{ inputs.gcs_download_uri }} + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Install 
Python dependencies + run: | + $JAXCI_PYTHON -m pip install uv~=0.5.30 + $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Pytest ROCm tests (abort support) + timeout-minutes: 180 + run: ./ci/run_pytest_rocm_abort.sh + - name: Upload pytest results to artifact + if: always() + uses: actions/upload-artifact@v4 + with: + name: logs_abort + path: | + logs_abort/ + if-no-files-found: warn + retention-days: 2 + overwrite: true diff --git a/.github/workflows/wheel_tests_nightly_release_abort.yml b/.github/workflows/wheel_tests_nightly_release_abort.yml new file mode 100644 index 000000000000..518306956eff --- /dev/null +++ b/.github/workflows/wheel_tests_nightly_release_abort.yml @@ -0,0 +1,53 @@ +# CI - Wheel Tests (Nightly/Release) (ROCm abort only) +# +# This workflow runs only the ROCm wheel tests using the abort/retry wrapper workflow. +name: CI - Wheel Tests (Nightly/Release) (ROCm abort only) + +on: + workflow_dispatch: + inputs: + gcs_download_uri: + description: "GCS location URI from where the artifacts should be downloaded" + required: true + default: 'gs://jax-nightly-artifacts/latest' + type: string + skip-download-jaxlib-and-plugins-from-gcs: + description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" + required: true + default: '0' + type: string + halt-for-connection: + description: 'Should this workflow run wait for a remote connection? 
(yes/no)' + required: false + default: 'no' + type: string + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true +permissions: {} + +env: + UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" + +jobs: + run-pytest-rocm: + uses: ./.github/workflows/pytest_rocm_abort.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + runner: ["linux-x86-64-1gpu-amd", "linux-x86-64-4gpu-amd", "linux-x86-64-8gpu-amd"] + python: ["3.11", "3.12", "3.13", "3.14"] + rocm: [ + {version: "7.2.0", tag: "rocm720"}, + ] + name: "Pytest ROCm abort (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + rocm-version: ${{ matrix.rocm.version }} + rocm-tag: ${{ matrix.rocm.tag }} + jaxlib-version: "head" + skip-download-jaxlib-and-plugins-from-gcs: ${{inputs.skip-download-jaxlib-and-plugins-from-gcs}} + gcs_download_uri: ${{inputs.gcs_download_uri}} + halt-for-connection: ${{inputs.halt-for-connection}} diff --git a/ci/run_pytest_rocm_abort.sh b/ci/run_pytest_rocm_abort.sh new file mode 100755 index 000000000000..f17e7184b3bf --- /dev/null +++ b/ci/run_pytest_rocm_abort.sh @@ -0,0 +1,179 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# Runs Pytest ROCm tests (with ROCm pytest-abort retry wrapper). +# Requires the jaxlib and ROCm plugin wheels to be present inside $JAXCI_OUTPUT_DIR (../dist) +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +source ci/envs/default.env + +# Install jaxlib and ROCm plugin wheels inside the $JAXCI_OUTPUT_DIR directory +echo "Installing wheels locally..." +source ./ci/utilities/install_wheels_locally.sh + +# Print all the installed packages +echo "Installed packages:" +"$JAXCI_PYTHON" -m uv pip freeze + +"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" + +rocm-smi + +# ============================================================================== +# Set up the generic test environment variables +# ============================================================================== +export PY_COLORS=1 +export JAX_SKIP_SLOW_TESTS=true +export NCCL_DEBUG=WARN +export TF_CPP_MIN_LOG_LEVEL=0 +export JAX_ENABLE_X64="$JAXCI_ENABLE_X64" + +# ============================================================================== +# Calculate the optimal number of parallel processes for pytest +# This will be the minimum of: GPU capacity, CPU core count, and a system RAM limit. +# ============================================================================== + +export gpu_count=$(rocminfo | egrep -c "Device Type:\\s+GPU") +echo "Number of GPUs detected: $gpu_count" + +# Query GPU 0 memory using rocm-smi +export memory_per_gpu_mib=$(rocm-smi -d 0 --showmeminfo vram | grep -i "vram total" | awk '{print int($NF/1024/1024)}' | head -1) +echo "Reported memory per GPU: $memory_per_gpu_mib MiB" + +# Convert effective memory from MiB to GiB. 
+export memory_per_gpu_gib=$((memory_per_gpu_mib / 1024)) +echo "Effective memory per GPU: $memory_per_gpu_gib GiB" + +# Allow 2 GiB of GPU RAM per test. +export max_tests_per_gpu=$((memory_per_gpu_gib / 2)) +echo "Max tests per GPU (assuming 2GiB/test): $max_tests_per_gpu" + +export num_processes=$((gpu_count * max_tests_per_gpu)) +echo "Initial number of processes based on GPU capacity: $num_processes" + +export num_cpu_cores=$(nproc) +echo "Number of CPU cores available: $num_cpu_cores" + +# Reads total memory from /proc/meminfo (in KiB) and converts to GiB. +export total_ram_gib=$(awk '/MemTotal/ {printf \"%.0f\", $2/1048576}' /proc/meminfo) +echo "Total system RAM: $total_ram_gib GiB" + +# Set a safety limit for system RAM usage, e.g., 1/6th of total. +export host_memory_limit=$((total_ram_gib / 6)) +echo "Host memory process limit (1/6th of total RAM): $host_memory_limit" + +if [[ $num_cpu_cores -lt $num_processes ]]; then + num_processes=$num_cpu_cores + echo "Adjusting num_processes to match CPU core count: $num_processes" +fi + +if [[ $host_memory_limit -lt $num_processes ]]; then + num_processes=$host_memory_limit + echo "Adjusting num_processes to match host memory limit: $num_processes" +fi + +if [[ 16 -lt $num_processes ]]; then + num_processes=16 + echo "Reducing num_processes to $num_processes" +fi + +echo "Final number of processes to run: $num_processes" + +export JAX_ENABLE_ROCM_XDIST="$gpu_count" +export XLA_PYTHON_CLIENT_ALLOCATOR=platform +export XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1 --xla_gpu_enable_nccl_comm_splitting=false --xla_gpu_enable_command_buffer=" + +# Disable core dumps just in case +ulimit -c 0 + +# Keep deselected tests in one place for the abort wrapper. 
+ROCM_PYTEST_DESELECT_ARGS=( + --deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data + --deselect=tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices + --deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric +) + +# --max-runs: retry the entire pytest run up to N times on abort/crash. +# --max-worker-restart: restart crashed xdist workers up to N times. +# --maxfail: stop the run after N test failures. +rocm_test_cmd() { + local abort_flag="${1:-0}" + shift + if [[ "$abort_flag" == "1" ]]; then + pytest-abort-retry --max-runs 3 --clear-crash-log -- "$@" + else + "$@" + fi +} + +rocm_log_tail_on_failure() { + local logfile="$1" + local status="$2" + if [[ "$status" -ne 0 ]]; then + echo "Pytest failed (exit=$status). Showing last 200 lines of $logfile:" + tail -n 200 "$logfile" || true + else + echo "Pytest output saved to $logfile (uploaded as artifact)." + fi +} + +rocm_install_extra_requirements() { + if [[ -n "${GITHUB_WORKSPACE:-}" ]]; then + cd "$GITHUB_WORKSPACE" + fi + + # Install extra requirements. + "$JAXCI_PYTHON" -m uv pip install pytest-timeout pytest-html pytest-csv pytest-json-report pytest-abort +} + +rocm_install_extra_requirements + +echo "Running ROCm tests (with abort/retry wrapper)..." +mkdir -p logs_abort +logfile="logs_abort/jax_ToT_UT_abort.log" + +# pytest-abort output directories (must be set before running pytest). 
+export PYTEST_ABORT_LAST_RUNNING_DIR="logs_abort/last_running" +export PYTEST_ABORT_CRASHED_TESTS_LOG="logs_abort/crashed_tests.jsonl" +mkdir -p "$PYTEST_ABORT_LAST_RUNNING_DIR" + +set +e +rocm_test_cmd 1 "$JAXCI_PYTHON" -m pytest -n "$num_processes" --max-worker-restart=200 --tb=short --timeout=1200 --timeout-method=thread tests \ + "${ROCM_PYTEST_DESELECT_ARGS[@]}" \ + --json-report \ + --json-report-file=logs_abort/tests-report-abort.json \ + --csv=logs_abort/tests-report-abort.csv \ + --html=logs_abort/tests-report-abort.html \ + --self-contained-html \ + >"$logfile" 2>&1 +pytest_status=$? +set -e +rocm_log_tail_on_failure "$logfile" "$pytest_status" + +echo "Postprocessing reports with crashed tests..." +pytest-abort-postprocess \ + --crash-log "$PYTEST_ABORT_CRASHED_TESTS_LOG" \ + --json-report logs_abort/tests-report-abort.json \ + --html-report logs_abort/tests-report-abort.html \ + --csv-report logs_abort/tests-report-abort.csv \ + >>"$logfile" 2>&1 + +exit "$pytest_status" diff --git a/conftest.py b/conftest.py index 72b4b598891c..ca144a6dcee4 100644 --- a/conftest.py +++ b/conftest.py @@ -15,10 +15,6 @@ import os import pytest -import json -import threading -import shutil -from datetime import datetime @pytest.fixture(autouse=True) def add_imports(doctest_namespace): @@ -76,174 +72,69 @@ def pytest_collection() -> None: "CUDA_VISIBLE_DEVICES", str(xdist_worker_number % num_cuda_devices) ) -class ThreadSafeTestLogger: - """Thread-safe logging for parallel test execution and abort detection""" - def __init__(self): - self.locks = {} - self.global_lock = threading.Lock() - self.base_dir = os.path.abspath("./logs") - - # Create logs directory (archiving is handled by test runner scripts) - try: - os.makedirs(self.base_dir, exist_ok=True) - print(f"[TestLogger] Initialized log directory: {self.base_dir}") - except Exception as e: - print(f"[TestLogger] ERROR: Failed to create log directory {self.base_dir}: {e}") - # Fallback to temp directory if logs dir 
creation fails - import tempfile - self.base_dir = os.path.join(tempfile.gettempdir(), "jax_test_logs") - os.makedirs(self.base_dir, exist_ok=True) - print(f"[TestLogger] Using fallback directory: {self.base_dir}") - - def get_file_lock(self, test_file): - """Get or create a lock for a specific test file""" - with self.global_lock: - if test_file not in self.locks: - self.locks[test_file] = threading.Lock() - return self.locks[test_file] - - def get_test_file_name(self, session): - """Extract the test file name from the session""" - # Try to get from session config args - if hasattr(session, "config") and hasattr(session.config, "args"): - for arg in session.config.args: - # Handle full nodeid like "jax/tests/foo_test.py::TestClass::test_method" - if "tests/" in arg: - # Split on :: to get just the file path - file_path = arg.split("::")[0] - if file_path.endswith(".py"): - return os.path.basename(file_path).replace(".py", "") - - # Try to get from invocation params - if hasattr(session, "config") and hasattr(session.config, "invocation_params"): - invocation_dir = getattr(session.config.invocation_params, "dir", None) - if invocation_dir: - dir_name = os.path.basename(str(invocation_dir)) - if dir_name: - print(f"[TestLogger] Using invocation directory as test name: {dir_name}") - return dir_name - - # Last resort: try to get from session items - if hasattr(session, "items") and session.items: - first_item = session.items[0] - if hasattr(first_item, "fspath"): - fspath = str(first_item.fspath) - if ".py" in fspath: - return os.path.basename(fspath).replace(".py", "") - - print(f"[TestLogger] WARNING: Could not determine test file name, using 'unknown_test'") - print(f"[TestLogger] Session config args: {getattr(session.config, 'args', 'N/A')}") - return "unknown_test" - - def log_running_test(self, test_file, test_name, nodeid, start_time): - """Log the currently running test for abort detection""" - lock = self.get_file_lock(test_file) - with lock: - log_data = { 
- "test_file": test_file, - "test_name": test_name, - "nodeid": nodeid, - "start_time": start_time, - "status": "running", - "pid": os.getpid(), - "gpu_id": os.environ.get("HIP_VISIBLE_DEVICES", "unknown"), - } - - log_file = f"{self.base_dir}/{test_file}_last_running.json" - try: - # Ensure directory still exists (might have been deleted) - os.makedirs(self.base_dir, exist_ok=True) - with open(log_file, "w") as f: - json.dump(log_data, f, indent=2) - except Exception as e: - print(f"[TestLogger] ERROR: Failed to write running test log to {log_file}: {e}") - print(f"[TestLogger] Current working directory: {os.getcwd()}") - print(f"[TestLogger] Base directory: {self.base_dir}") - print(f"[TestLogger] Base directory exists: {os.path.exists(self.base_dir)}") - raise - - def clear_running_test(self, test_file): - """Clear the running test log when test completes successfully""" - lock = self.get_file_lock(test_file) - with lock: - log_file = f"{self.base_dir}/{test_file}_last_running.json" - if os.path.exists(log_file): - os.remove(log_file) - - -# Global logger instance -test_logger = ThreadSafeTestLogger() - - -@pytest.hookimpl(hookwrapper=True) -def pytest_runtest_protocol(item, nextitem): - """Hook that wraps around each test to track running tests for crash detection. - - This creates a "last_running" file before each test starts and deletes it - when the test completes successfully. If the test crashes, the file remains - and can be detected by the test runner. 
- """ - test_file = test_logger.get_test_file_name(item.session) - test_name = item.name - nodeid = item.nodeid - start_time = datetime.now().isoformat() - - # Log that this test is starting - try: - test_logger.log_running_test(test_file, test_name, nodeid, start_time) - except Exception as e: - print(f"[TestLogger] WARNING: Failed to log running test: {e}") - # Continue anyway - not critical for test execution - - test_completed = False - try: - outcome = yield - # Test completed (successfully or with normal failure) - test_completed = True - - # Clear the crash detection file - try: - test_logger.clear_running_test(test_file) - except Exception as e: - print(f"[TestLogger] WARNING: Failed to clear running test log: {e}") - - except Exception as e: - # Test raised exception (might be crash, might be normal exception) - print(f"[TestLogger] Test {test_name} exception: {e}") - if not test_completed: - # Don't clear the file - this might be a crash - print(f"[TestLogger] Leaving crash file for detection") - raise + elif num_rocm_devices := os.environ.get("JAX_ENABLE_ROCM_XDIST", None): + num_rocm_devices = int(num_rocm_devices) + xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "") + if not xdist_worker_name.startswith("gw"): + return + xdist_worker_number = int(xdist_worker_name[len("gw") :]) + assigned = str(xdist_worker_number % num_rocm_devices) + # If ROCR_VISIBLE_DEVICES is set, don't also set HIP_VISIBLE_DEVICES + # (double-filtering can produce HIP_ERROR_NoDevice). Respect the outer setting. + if os.environ.get("ROCR_VISIBLE_DEVICES"): + return -@pytest.hookimpl(tryfirst=True) -def pytest_sessionstart(session): - """Called after the Session object has been created""" - gpu = os.environ.get('HIP_VISIBLE_DEVICES', '?') - print(f"Test session starting on GPU {gpu}") + # If present-but-empty, this can hide all GPUs. 
+ if os.environ.get("HIP_VISIBLE_DEVICES", None) == "": + del os.environ["HIP_VISIBLE_DEVICES"] + + # HIP layer isolation (ROCm also accepts CUDA_VISIBLE_DEVICES, but we avoid it here). + os.environ["HIP_VISIBLE_DEVICES"] = assigned + +def pytest_configure(config) -> None: + # Real pytest hook (runs early in main + each xdist worker). + xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "") or "main" + + # xdist master: print planned mapping (worker stdout is often hidden) + numproc = int(getattr(getattr(config, "option", None), "numprocesses", 0) or 0) + if xdist_worker_name == "main" and numproc > 0: + hip0 = (os.environ.get("HIP_VISIBLE_DEVICES") or "").strip() + cuda_x = (os.environ.get("JAX_ENABLE_CUDA_XDIST") or "").strip() + tpu_x = (os.environ.get("JAX_ENABLE_TPU_XDIST") or "").strip() + rocm_x = (os.environ.get("JAX_ENABLE_ROCM_XDIST") or "").strip() + if cuda_x: + try: + ndev = int(cuda_x) + except ValueError: + ndev = 0 + if ndev > 0: + mapping = ", ".join(f"gw{i}->CUDA_VISIBLE_DEVICES={i % ndev}" for i in range(numproc)) + print(f"[DeviceVisibility] xdist planned mapping: {mapping}", flush=True) + elif tpu_x: + mapping = ", ".join(f"gw{i}->TPU_VISIBLE_CHIPS={i}" for i in range(numproc)) + print(f"[DeviceVisibility] xdist planned mapping: {mapping}", flush=True) + elif rocm_x: + try: + ndev = int(rocm_x) + except ValueError: + ndev = 0 + if ndev > 0: + mapping = ", ".join(f"gw{i}->HIP_VISIBLE_DEVICES={i % ndev}" for i in range(numproc)) + print(f"[DeviceVisibility] xdist planned mapping: {mapping}", flush=True) + elif hip0: + print(f"[DeviceVisibility] master HIP_VISIBLE_DEVICES={hip0}", flush=True) + if os.environ.get("JAX_ENABLE_TPU_XDIST", None): + if xdist_worker_name.startswith("gw"): + xdist_worker_number = int(xdist_worker_name[len("gw") :]) + os.environ.setdefault("TPU_VISIBLE_CHIPS", str(xdist_worker_number)) + os.environ.setdefault("ALLOW_MULTIPLE_LIBTPU_LOAD", "true") -@pytest.hookimpl(trylast=True) -def pytest_sessionfinish(session, 
exitstatus): - """Called after test run finished. - - If a crash file still exists, it means a test crashed and the runner - will detect it. We just report it here for visibility. - """ - test_file = test_logger.get_test_file_name(session) - log_file = f"{test_logger.base_dir}/{test_file}_last_running.json" - - if os.path.exists(log_file): - try: - with open(log_file, "r") as f: - abort_data = json.load(f) - print( - f"\n[CRASH DETECTED] {abort_data.get('nodeid', abort_data.get('test_name', 'unknown'))} " - f"(GPU: {abort_data.get('gpu_id', '?')}, PID: {abort_data.get('pid', '?')})" - ) - print(f"[CRASH DETECTED] Crash file will be processed by test runner") - except Exception as e: - print(f"[TestLogger] WARNING: Crash file exists but unreadable: {e}") - else: - # Normal completion - no crash - pass + elif num_cuda_devices := os.environ.get("JAX_ENABLE_CUDA_XDIST", None): + if xdist_worker_name.startswith("gw"): + num_cuda_devices = int(num_cuda_devices) + xdist_worker_number = int(xdist_worker_name[len("gw") :]) + os.environ.setdefault( + "CUDA_VISIBLE_DEVICES", str(xdist_worker_number % num_cuda_devices) + ) From d36ebc22060387976b8059af9556e825745e4f8c Mon Sep 17 00:00:00 2001 From: Gulsum Gudukbay Akbulut Date: Tue, 24 Feb 2026 09:59:51 -0600 Subject: [PATCH 42/44] Abort-Detection: Fix halt-for-connection input (#712) --- .github/workflows/pytest_rocm_abort.yml | 13 +++++++++++++ ci/run_pytest_rocm_abort.sh | 7 +++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pytest_rocm_abort.yml b/.github/workflows/pytest_rocm_abort.yml index fcfdb06855d5..854b6681978b 100644 --- a/.github/workflows/pytest_rocm_abort.yml +++ b/.github/workflows/pytest_rocm_abort.yml @@ -64,6 +64,10 @@ on: description: 'Should this workflow run wait for a remote connection?' 
type: string default: 'no' + max-worker-restart: + description: "Max xdist worker restarts (passed to pytest --max-worker-restart)" + type: string + default: '50' workflow_call: inputs: runner: @@ -94,6 +98,14 @@ on: description: "GCS location prefix from where the artifacts should be downloaded" default: 'gs://jax-nightly-artifacts/latest' type: string + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: string + default: 'no' + max-worker-restart: + description: "Max xdist worker restarts (passed to pytest --max-worker-restart)" + type: string + default: '50' permissions: {} @@ -123,6 +135,7 @@ jobs: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" JAXCI_PYTHON: "python${{ inputs.python }}" JAXCI_ENABLE_X64: "0" + MAX_WORKER_RESTART: "${{ inputs['max-worker-restart'] }}" steps: - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 diff --git a/ci/run_pytest_rocm_abort.sh b/ci/run_pytest_rocm_abort.sh index f17e7184b3bf..6c6ab6d18281 100755 --- a/ci/run_pytest_rocm_abort.sh +++ b/ci/run_pytest_rocm_abort.sh @@ -73,7 +73,7 @@ export num_cpu_cores=$(nproc) echo "Number of CPU cores available: $num_cpu_cores" # Reads total memory from /proc/meminfo (in KiB) and converts to GiB. -export total_ram_gib=$(awk '/MemTotal/ {printf \"%.0f\", $2/1048576}' /proc/meminfo) +export total_ram_gib=$(awk '/MemTotal/ {printf "%.0f", $2/1048576}' /proc/meminfo) echo "Total system RAM: $total_ram_gib GiB" # Set a safety limit for system RAM usage, e.g., 1/6th of total. @@ -150,13 +150,16 @@ echo "Running ROCm tests (with abort/retry wrapper)..." mkdir -p logs_abort logfile="logs_abort/jax_ToT_UT_abort.log" +# Allow the workflow to override worker restart limit. +max_worker_restart="${MAX_WORKER_RESTART:-50}" + # pytest-abort output directories (must be set before running pytest). 
export PYTEST_ABORT_LAST_RUNNING_DIR="logs_abort/last_running" export PYTEST_ABORT_CRASHED_TESTS_LOG="logs_abort/crashed_tests.jsonl" mkdir -p "$PYTEST_ABORT_LAST_RUNNING_DIR" set +e -rocm_test_cmd 1 "$JAXCI_PYTHON" -m pytest -n "$num_processes" --max-worker-restart=200 --tb=short --timeout=1200 --timeout-method=thread tests \ +rocm_test_cmd 1 "$JAXCI_PYTHON" -m pytest -n "$num_processes" --max-worker-restart="$max_worker_restart" --tb=short --timeout=1200 --timeout-method=thread tests \ "${ROCM_PYTEST_DESELECT_ARGS[@]}" \ --json-report \ --json-report-file=logs_abort/tests-report-abort.json \ From e793527227eee820c75c1e77e77e43b4d1e181e2 Mon Sep 17 00:00:00 2001 From: Pakize Sanal Date: Wed, 25 Feb 2026 13:03:33 +0000 Subject: [PATCH 43/44] Drop call-id input from reusable workflow --- .github/workflows/pytest_rocm.yml | 83 +++++++++++++++++++ .../workflows/wheel_tests_nightly_release.yml | 1 + build/test-requirements.txt | 1 + ci/run_pytest_rocm.sh | 2 + 4 files changed, 87 insertions(+) diff --git a/.github/workflows/pytest_rocm.yml b/.github/workflows/pytest_rocm.yml index ed7a5a6dc5e9..8ddefd192b04 100644 --- a/.github/workflows/pytest_rocm.yml +++ b/.github/workflows/pytest_rocm.yml @@ -93,6 +93,13 @@ on: description: "GCS location prefix from where the artifacts should be downloaded" default: 'gs://jax-nightly-artifacts/latest' type: string + secrets: + AWS_ACCESS_KEY_ID: + required: true + AWS_SECRET_ACCESS_KEY: + required: true + S3_BUCKET_NAME: + required: true permissions: {} env: @@ -148,3 +155,79 @@ jobs: - name: Run Pytest ROCm tests timeout-minutes: 120 run: ./ci/run_pytest_rocm.sh + - name: Create a logs archive + if: always() + run: | + set -euo pipefail + tar -czf logs.tar.gz logs + - name: Collect run-manifest info + if: always() + env: + INPUT_RUNNER: ${{ inputs.runner }} + INPUT_PYTHON: ${{ inputs.python }} + INPUT_ROCM_VERSION: ${{ inputs.rocm-version }} + INPUT_ROCM_TAG: ${{ inputs.rocm-tag }} + run: | + PKGS="$($JAXCI_PYTHON -m pip list | 
grep -E '^(jax|jaxlib)|pjrt|plugin' || true)" + PKGS_ONE_LINE="$(printf "%s" "$PKGS" | tr '\n' '|' | sed 's/|$//')" + WHEELS="$(sha256sum dist/*.whl 2>/dev/null || true)" + WHEELS_ONE_LINE="$(printf "%s" "$WHEELS" | tr '\n' '|' | sed 's/|$//')" + + REPO="rocm/jax-base-ubu24.${INPUT_ROCM_TAG}" + IMAGE="ghcr.io/${REPO}:latest" + BASE="https://ghcr.io" + DIGEST="" + TOKEN="$(curl -fsSL "${BASE}/token?service=ghcr.io&scope=repository:${REPO}:pull" \ + | sed -n 's/.*"token":"\([^"]*\)".*/\1/p' || true)" + if [ -n "${TOKEN}" ]; then + DIGEST="$(curl -fsSL -D - \ + -H "Authorization: Bearer ${TOKEN}" \ + -H "Accept: application/vnd.docker.distribution.manifest.v2+json" \ + "${BASE}/v2/${REPO}/manifests/latest" -o /dev/null \ + | awk -F': ' 'tolower($1)=="docker-content-digest"{print $2}' \ + | tr -d $'\r' || true)" + fi + + cat > logs/run-manifest.json << EOF + { + "schema_version": 1, + "created_at": "$(date -u +"%Y-%m-%dT%H:%M:%SZ")", + "github_repository": "${{ github.repository }}", + "github_ref_name": "${{ github.ref_name }}", + "github_ref": "${{ github.ref }}", + "github_sha": "${{ github.sha }}", + "github_event_name": "${{ github.event_name }}", + "github_run_id": "${{ github.run_id }}", + "github_run_attempt": "${{ github.run_attempt }}", + "github_run_number": "${{ github.run_number }}", + "github_workflow": "${{ github.workflow }}", + "github_job": "${{ github.job }}", + "python_version": "${INPUT_PYTHON}", + "rocm_version": "${INPUT_ROCM_VERSION}", + "runner": "${INPUT_RUNNER}", + "base_image_name": "${IMAGE}", + "base_image_digest": "${DIGEST}", + "jax_packages_raw": "${PKGS_ONE_LINE}", + "wheels_sha_raw": "${WHEELS_ONE_LINE}" + } + EOF + - name: Upload logs to AMD S3 bucket + if: always() + env: + S3_BUCKET_NAME: ${{ secrets.S3_BUCKET_NAME }} + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + RUN_KEY: repo=${{ github.repository }}/run_id=${{ github.run_id }}/att=${{ github.run_attempt }} + 
JOB_KEY: job=${{ inputs.call-id || github.job }} + COMBO: py=${{ inputs.python }}/rocm=${{ inputs.rocm-version }}/runner=${{ inputs.runner }} + run: | + set -euo pipefail + PREFIX="logs/${RUN_KEY}/${JOB_KEY}/${COMBO}/" + $JAXCI_PYTHON -m pip install -q boto3 + $JAXCI_PYTHON - <=11.3 portpicker pytest<9.0 # Works around https://github.com/pytest-dev/pytest/issues/13895 pytest-xdist +pytest-json-report rich matplotlib auditwheel diff --git a/ci/run_pytest_rocm.sh b/ci/run_pytest_rocm.sh index b22034cf2f62..1b1f26d9fc0f 100755 --- a/ci/run_pytest_rocm.sh +++ b/ci/run_pytest_rocm.sh @@ -114,8 +114,10 @@ echo "Running ROCm tests..." # TODO: Add examples directory to test suite (CUDA tests both: tests examples) # TODO: Verify if CSV/HTML report generation should be kept (unique to ROCm) # TODO: Verify if log file output should be kept (unique to ROCm) +mkdir -p logs export NPROC=32 "$JAXCI_PYTHON" -m pytest -n $num_processes --tb=short \ +--json-report --json-report-file=logs/pytest_results.json \ tests \ --deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \ --deselect=tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices \ From c58f3e9e1cd9401ad2506f29cbe7e527c050397e Mon Sep 17 00:00:00 2001 From: Pakize Sanal Date: Sat, 28 Feb 2026 19:07:58 +0000 Subject: [PATCH 44/44] Add upload_rocm_logs.sh to push CI logs and manifest to S3 --- .github/workflows/pytest_rocm.yml | 80 ++-------- .../workflows/wheel_tests_nightly_release.yml | 3 + ci/run_pytest_rocm.sh | 6 +- ci/upload_rocm_logs.sh | 146 ++++++++++++++++++ 4 files changed, 167 insertions(+), 68 deletions(-) create mode 100755 ci/upload_rocm_logs.sh diff --git a/.github/workflows/pytest_rocm.yml b/.github/workflows/pytest_rocm.yml index 8ddefd192b04..8bac684debc3 100644 --- a/.github/workflows/pytest_rocm.yml +++ b/.github/workflows/pytest_rocm.yml @@ -100,7 +100,9 @@ on: required: true S3_BUCKET_NAME: required: true -permissions: {} +permissions: + 
actions: read + contents: read env: UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" @@ -155,79 +157,25 @@ jobs: - name: Run Pytest ROCm tests timeout-minutes: 120 run: ./ci/run_pytest_rocm.sh - - name: Create a logs archive + - name: Archive test logs if: always() run: | set -euo pipefail tar -czf logs.tar.gz logs - - name: Collect run-manifest info - if: always() - env: - INPUT_RUNNER: ${{ inputs.runner }} - INPUT_PYTHON: ${{ inputs.python }} - INPUT_ROCM_VERSION: ${{ inputs.rocm-version }} - INPUT_ROCM_TAG: ${{ inputs.rocm-tag }} + - name: Configure AWS Credentials run: | - PKGS="$($JAXCI_PYTHON -m pip list | grep -E '^(jax|jaxlib)|pjrt|plugin' || true)" - PKGS_ONE_LINE="$(printf "%s" "$PKGS" | tr '\n' '|' | sed 's/|$//')" - WHEELS="$(sha256sum dist/*.whl 2>/dev/null || true)" - WHEELS_ONE_LINE="$(printf "%s" "$WHEELS" | tr '\n' '|' | sed 's/|$//')" - - REPO="rocm/jax-base-ubu24.${INPUT_ROCM_TAG}" - IMAGE="ghcr.io/${REPO}:latest" - BASE="https://ghcr.io" - DIGEST="" - TOKEN="$(curl -fsSL "${BASE}/token?service=ghcr.io&scope=repository:${REPO}:pull" \ - | sed -n 's/.*"token":"\([^"]*\)".*/\1/p' || true)" - if [ -n "${TOKEN}" ]; then - DIGEST="$(curl -fsSL -D - \ - -H "Authorization: Bearer ${TOKEN}" \ - -H "Accept: application/vnd.docker.distribution.manifest.v2+json" \ - "${BASE}/v2/${REPO}/manifests/latest" -o /dev/null \ - | awk -F': ' 'tolower($1)=="docker-content-digest"{print $2}' \ - | tr -d $'\r' || true)" - fi - - cat > logs/run-manifest.json << EOF - { - "schema_version": 1, - "created_at": "$(date -u +"%Y-%m-%dT%H:%M:%SZ")", - "github_repository": "${{ github.repository }}", - "github_ref_name": "${{ github.ref_name }}", - "github_ref": "${{ github.ref }}", - "github_sha": "${{ github.sha }}", - "github_event_name": "${{ github.event_name }}", - "github_run_id": "${{ github.run_id }}", - "github_run_attempt": "${{ github.run_attempt }}", - "github_run_number": "${{ github.run_number }}", - "github_workflow": 
"${{ github.workflow }}", - "github_job": "${{ github.job }}", - "python_version": "${INPUT_PYTHON}", - "rocm_version": "${INPUT_ROCM_VERSION}", - "runner": "${INPUT_RUNNER}", - "base_image_name": "${IMAGE}", - "base_image_digest": "${DIGEST}", - "jax_packages_raw": "${PKGS_ONE_LINE}", - "wheels_sha_raw": "${WHEELS_ONE_LINE}" - } - EOF - - name: Upload logs to AMD S3 bucket + echo "config AWS cred." + - name: Upload test-artifacts to AMD S3 if: always() env: S3_BUCKET_NAME: ${{ secrets.S3_BUCKET_NAME }} AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - RUN_KEY: repo=${{ github.repository }}/run_id=${{ github.run_id }}/att=${{ github.run_attempt }} - JOB_KEY: job=${{ inputs.call-id || github.job }} - COMBO: py=${{ inputs.python }}/rocm=${{ inputs.rocm-version }}/runner=${{ inputs.runner }} + GITHUB_TOKEN: ${{ github.token }} + INPUT_PYTHON: ${{ inputs.python }} + INPUT_ROCM_VERSION: ${{ inputs.rocm-version }} + INPUT_RUNNER: ${{ inputs.runner }} + INPUT_ROCM_TAG: ${{ inputs.rocm-tag }} + IS_NIGHTLY: ${{ contains(github.workflow, 'Nightly/Release') && 'nightly' || 'continuous' }} run: | - set -euo pipefail - PREFIX="logs/${RUN_KEY}/${JOB_KEY}/${COMBO}/" - $JAXCI_PYTHON -m pip install -q boto3 - $JAXCI_PYTHON - <////__// +set -euo pipefail + +: "${S3_BUCKET_NAME:?}" +: "${INPUT_PYTHON:?}" +: "${INPUT_ROCM_VERSION:?}" +: "${INPUT_ROCM_TAG:?}" +: "${INPUT_RUNNER:?}" +: "${IS_NIGHTLY:?}" # nightly|continuous + +TEST_LOGS_ROOT="jax-ci-test-logs" + +norm() { printf '%s' "$1" | tr '.-' '_' ; } + +# Timestamp: GitHub run_started_at; fall back to UTC date. 
+RUN_STARTED_AT="" +if [[ -n "${GITHUB_TOKEN:-}" ]]; then + RUN_STARTED_AT="$( + curl -fsSL \ + -H "Authorization: Bearer ${GITHUB_TOKEN}" \ + -H "Accept: application/vnd.github+json" \ + "https://api.github.com/repos/${GITHUB_REPOSITORY}/actions/runs/${GITHUB_RUN_ID}" \ + | sed -n 's/.*"run_started_at"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/p' \ + | head -n1 || true + )" +fi + +DATE="${RUN_STARTED_AT%%T*}" +[[ -n "${DATE}" ]] || DATE="$(date -u +%F)" + +# GPU count from runner name (e.g. linux-x86-64-8gpu-amd -> 8). +GPU_COUNT="" +if [[ "${INPUT_RUNNER}" =~ ([0-9]+)gpu ]]; then + GPU_COUNT="${BASH_REMATCH[1]}" +fi + +GPU_PART="${GPU_COUNT:+gpu_${GPU_COUNT}}" +GPU_PART="${GPU_PART:-${INPUT_RUNNER}}" + +RUN_DIR="${DATE}_${GITHUB_RUN_ID}_${GITHUB_RUN_ATTEMPT}" +COMBO="py$(norm "${INPUT_PYTHON}")-rocm$(norm "${INPUT_ROCM_VERSION}")-${GPU_PART}" +PREFIX="${GITHUB_REPOSITORY}/${GITHUB_REF_NAME}/${IS_NIGHTLY}/${RUN_DIR}/${COMBO}" + +DEST="s3://${S3_BUCKET_NAME}/${TEST_LOGS_ROOT}/${PREFIX}" + +echo "Uploading ROCm pytest artifacts" + +# Upload archive first (created in YAML) +ARCHIVE="logs.tar.gz" +[[ -f "${ARCHIVE}" ]] || { echo "Missing ${ARCHIVE}"; exit 2; } + +echo "Uploading logs.tar.gz" +aws s3 cp --only-show-errors "${ARCHIVE}" "${DEST}/${ARCHIVE}" + +PYTHON="${JAXCI_PYTHON:-python3}" +# Packages/wheels (best-effort) +PKGS_RAW="$( + "${PYTHON}" -m pip list --format=freeze 2>/dev/null \ + | grep -E '^(jax|jaxlib)==|pjrt|plugin' \ + || true +)" +PKGS_ONE_LINE="$(printf "%s" "${PKGS_RAW}" | tr '\n' '|' | sed 's/|$//')" + +WHEELS_RAW="$(sha256sum dist/*.whl 2>/dev/null || true)" +WHEELS_ONE_LINE="$(printf "%s" "${WHEELS_RAW}" | tr '\n' '|' | sed 's/|$//')" + +# Base image digest (best-effort) +GHCR_REPO="rocm/jax-base-ubu24.${INPUT_ROCM_TAG}" +IMAGE="ghcr.io/${GHCR_REPO}:latest" +DIGEST="" + +TOKEN="$( + curl -fsSL "https://ghcr.io/token?service=ghcr.io&scope=repository:${GHCR_REPO}:pull" \ + | sed -n 's/.*"token"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/p' \ + || 
true +)" +if [[ -n "${TOKEN}" ]]; then + DIGEST="$( + curl -fsSL -D - \ + -H "Authorization: Bearer ${TOKEN}" \ + -H "Accept: application/vnd.docker.distribution.manifest.v2+json" \ + "https://ghcr.io/v2/${GHCR_REPO}/manifests/latest" -o /dev/null \ + | awk -F': ' 'tolower($1)=="docker-content-digest"{print $2}' \ + | tr -d $'\r' \ + | head -n1 || true + )" +fi + +RUN_URL="https://github.com/${GITHUB_REPOSITORY}/actions/runs/${GITHUB_RUN_ID}" + +echo "Uploading run-manifest.json" +aws s3 cp --only-show-errors - "${DEST}/run-manifest.json" <