diff --git a/.github/workflows/pytest_rocm_abort.yml b/.github/workflows/pytest_rocm_abort.yml new file mode 100644 index 000000000000..fcfdb06855d5 --- /dev/null +++ b/.github/workflows/pytest_rocm_abort.yml @@ -0,0 +1,162 @@ +# CI - Pytest ROCm (Abort Support) +# +# This workflow runs the ROCm tests with Pytest in ROCm GHCR containers, +# using the ROCm `pytest-abort` retry wrapper to detect/retry aborts/crashes. +# +# It can be triggered manually via workflow_dispatch or called by other workflows +# via workflow_call. +# +# It consists of the following job: +# run-tests: +# - Runs in ROCm container (ghcr.io/rocm/jax-base-ubu24-rocm*:latest) +# - Downloads the JAX and jaxlib wheels from GCS, and ROCm plugins from latest release. +# - Executes the `run_pytest_rocm_abort.sh` script, which installs wheel artifacts and +# runs the ROCm tests with Pytest under `pytest-abort-retry`. +name: CI - Pytest ROCm (Abort Support) + +on: + workflow_dispatch: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: choice + default: "linux-x86-64-4gpu-amd" + options: + - "linux-x86-64-1gpu-amd" + - "linux-x86-64-4gpu-amd" + - "linux-x86-64-8gpu-amd" + python: + description: "Which Python version to use?" + type: choice + default: "3.11" + options: + - "3.11" + - "3.12" + rocm-version: + description: "Which ROCm version to test?" + type: choice + default: "7.2.0" + options: + - "7.2.0" + rocm-tag: + description: "ROCm tag for container image (e.g., rocm720)" + type: string + default: "rocm720" + jaxlib-version: + description: "Which jaxlib version to use? (head/pypi_latest)" + type: choice + default: "head" + options: + - "head" + - "pypi_latest" + skip-download-jaxlib-and-plugins-from-gcs: + description: "Whether to skip downloading the jaxlib and plugins from GCS (e.g for testing a jax only release)" + type: choice + default: '0' + options: + - '0' + - '1' + gcs_download_uri: + description: "GCS location prefix from where the artifacts should be downloaded" + type: string + default: 'gs://jax-nightly-artifacts/latest' + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: string + default: 'no' + workflow_call: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: string + default: "linux-x86-64-4gpu-amd" + python: + description: "Which Python version to use?" + type: string + default: "3.11" + rocm-version: + description: "Which ROCm version to test?" + type: string + default: "7.2.0" + rocm-tag: + description: "ROCm tag for container image (e.g., rocm720)" + type: string + default: "rocm720" + jaxlib-version: + description: "Which jaxlib version to use? (head/pypi_latest)" + type: string + default: "head" + skip-download-jaxlib-and-plugins-from-gcs: + description: "Whether to skip downloading the jaxlib and plugins from GCS (e.g for testing a jax only release)" + default: '0' + type: string + gcs_download_uri: + description: "GCS location prefix from where the artifacts should be downloaded" + default: 'gs://jax-nightly-artifacts/latest' + type: string + +permissions: {} + +env: + UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" + +jobs: + run-tests: + defaults: + run: + # Set the shell to bash as GitHub actions run with /bin/sh by default + shell: bash + runs-on: ${{ inputs.runner }} + continue-on-error: true + # Run in ROCm GHCR container with GPU access + container: + image: ghcr.io/rocm/jax-base-ubu24.${{ inputs.rocm-tag }}:latest + credentials: + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --shm-size 64G --env-file /etc/podinfo/gha-gpu-isolation-settings + name: "${{ (contains(inputs.runner, '1gpu') && '1gpu') || + (contains(inputs.runner, '4gpu') && '4gpu') || + (contains(inputs.runner, '8gpu') && '8gpu') }}, ROCm ${{ inputs.rocm-version }}, py${{ inputs.python }}" + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" + JAXCI_PYTHON: "python${{ inputs.python }}" + JAXCI_ENABLE_X64: "0" + + steps: + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false + - name: Download JAX ROCm wheels + uses: ./.github/actions/download-jax-rocm-wheels + with: + python: ${{ inputs.python }} + rocm-version: ${{ inputs.rocm-version }} + jaxlib-version: ${{ inputs.jaxlib-version }} + skip-download-jaxlib-and-plugins-from-gcs: ${{ inputs.skip-download-jaxlib-and-plugins-from-gcs }} + gcs_download_uri: ${{ inputs.gcs_download_uri }} + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Install Python dependencies + run: | + $JAXCI_PYTHON -m pip install uv~=0.5.30 + $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Pytest ROCm tests (abort support) + timeout-minutes: 180 + run: ./ci/run_pytest_rocm_abort.sh + - name: Upload pytest results to artifact + if: always() + uses: actions/upload-artifact@v4 + with: + name: logs_abort + path: | + logs_abort/ + if-no-files-found: warn + retention-days: 2 + overwrite: true diff --git a/.github/workflows/wheel_tests_nightly_release_abort.yml b/.github/workflows/wheel_tests_nightly_release_abort.yml new file mode 100644 index 000000000000..518306956eff --- /dev/null +++ b/.github/workflows/wheel_tests_nightly_release_abort.yml @@ -0,0 +1,53 @@ +# CI - Wheel Tests (Nightly/Release) (ROCm abort only) +# +# This workflow runs only the ROCm wheel tests using the abort/retry wrapper workflow. +name: CI - Wheel Tests (Nightly/Release) (ROCm abort only) + +on: + workflow_dispatch: + inputs: + gcs_download_uri: + description: "GCS location URI from where the artifacts should be downloaded" + required: true + default: 'gs://jax-nightly-artifacts/latest' + type: string + skip-download-jaxlib-and-plugins-from-gcs: + description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" + required: true + default: '0' + type: string + halt-for-connection: + description: 'Should this workflow run wait for a remote connection? (yes/no)' + required: false + default: 'no' + type: string + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true +permissions: {} + +env: + UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple" + +jobs: + run-pytest-rocm: + uses: ./.github/workflows/pytest_rocm_abort.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + runner: ["linux-x86-64-1gpu-amd", "linux-x86-64-4gpu-amd", "linux-x86-64-8gpu-amd"] + python: ["3.11", "3.12", "3.13", "3.14"] + rocm: [ + {version: "7.2.0", tag: "rocm720"}, + ] + name: "Pytest ROCm abort (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + rocm-version: ${{ matrix.rocm.version }} + rocm-tag: ${{ matrix.rocm.tag }} + jaxlib-version: "head" + skip-download-jaxlib-and-plugins-from-gcs: ${{inputs.skip-download-jaxlib-and-plugins-from-gcs}} + gcs_download_uri: ${{inputs.gcs_download_uri}} + halt-for-connection: ${{inputs.halt-for-connection}} diff --git a/.gitignore b/.gitignore index d30c019b31e8..e55a9a5de101 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,7 @@ jax.iml /include/ /lib/ /share/ + +/compile_commands.json +/strace.txt +/external diff --git a/build/build.py b/build/build.py index 79de19508687..478ae0586c63 100755 --- a/build/build.py +++ b/build/build.py @@ -638,6 +638,10 @@ async def main(): ) if "rocm" in args.wheels: + if not args.configure_only: + print("ERROR: This repo is not used for building the ROCm JAX plugins. Please use the new plugin repo: https://github.com/ROCm/rocm-jax") + exit(1) + wheel_build_command_base.append("--config=rocm_base") wheel_build_command_base.append("--config=rocm") if clang_local: diff --git a/build/rocm-test-requirements.txt b/build/rocm-test-requirements.txt new file mode 100644 index 000000000000..399237175957 --- /dev/null +++ b/build/rocm-test-requirements.txt @@ -0,0 +1,24 @@ +absl-py +build +cloudpickle +colorama>=0.4.4 +filelock +flatbuffers +hypothesis +mpmath>=1.3 +pillow>=10.4.0 +# TODO(kanglan): Remove once psutil from portpicker supports python 3.13t +portpicker; python_version<"3.13" +pytest-xdist +pytest-json-report +pytest-html +pytest-csv +pytest-rerunfailures +pytest-html-merger +pytest-reportlog +wheel +rich +setuptools +matplotlib +opt-einsum +auditwheel diff --git a/ci/run_pytest_rocm_abort.sh b/ci/run_pytest_rocm_abort.sh new file mode 100755 index 000000000000..f17e7184b3bf --- /dev/null +++ b/ci/run_pytest_rocm_abort.sh @@ -0,0 +1,179 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Runs Pytest ROCm tests (with ROCm pytest-abort retry wrapper). +# Requires the jaxlib and ROCm plugin wheels to be present inside $JAXCI_OUTPUT_DIR (../dist) +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +source ci/envs/default.env + +# Install jaxlib and ROCm plugin wheels inside the $JAXCI_OUTPUT_DIR directory +echo "Installing wheels locally..." +source ./ci/utilities/install_wheels_locally.sh + +# Print all the installed packages +echo "Installed packages:" +"$JAXCI_PYTHON" -m uv pip freeze + +"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" + +rocm-smi + +# ============================================================================== +# Set up the generic test environment variables +# ============================================================================== +export PY_COLORS=1 +export JAX_SKIP_SLOW_TESTS=true +export NCCL_DEBUG=WARN +export TF_CPP_MIN_LOG_LEVEL=0 +export JAX_ENABLE_X64="$JAXCI_ENABLE_X64" + +# ============================================================================== +# Calculate the optimal number of parallel processes for pytest +# This will be the minimum of: GPU capacity, CPU core count, and a system RAM limit. +# ============================================================================== + +export gpu_count=$(rocminfo | egrep -c "Device Type:\\s+GPU") +echo "Number of GPUs detected: $gpu_count" + +# Query GPU 0 memory using rocm-smi +export memory_per_gpu_mib=$(rocm-smi -d 0 --showmeminfo vram | grep -i "vram total" | awk '{print int($NF/1024/1024)}' | head -1) +echo "Reported memory per GPU: $memory_per_gpu_mib MiB" + +# Convert effective memory from MiB to GiB. +export memory_per_gpu_gib=$((memory_per_gpu_mib / 1024)) +echo "Effective memory per GPU: $memory_per_gpu_gib GiB" + +# Allow 2 GiB of GPU RAM per test. +export max_tests_per_gpu=$((memory_per_gpu_gib / 2)) +echo "Max tests per GPU (assuming 2GiB/test): $max_tests_per_gpu" + +export num_processes=$((gpu_count * max_tests_per_gpu)) +echo "Initial number of processes based on GPU capacity: $num_processes" + +export num_cpu_cores=$(nproc) +echo "Number of CPU cores available: $num_cpu_cores" + +# Reads total memory from /proc/meminfo (in KiB) and converts to GiB. +export total_ram_gib=$(awk '/MemTotal/ {printf \"%.0f\", $2/1048576}' /proc/meminfo) +echo "Total system RAM: $total_ram_gib GiB" + +# Set a safety limit for system RAM usage, e.g., 1/6th of total. +export host_memory_limit=$((total_ram_gib / 6)) +echo "Host memory process limit (1/6th of total RAM): $host_memory_limit" + +if [[ $num_cpu_cores -lt $num_processes ]]; then + num_processes=$num_cpu_cores + echo "Adjusting num_processes to match CPU core count: $num_processes" +fi + +if [[ $host_memory_limit -lt $num_processes ]]; then + num_processes=$host_memory_limit + echo "Adjusting num_processes to match host memory limit: $num_processes" +fi + +if [[ 16 -lt $num_processes ]]; then + num_processes=16 + echo "Reducing num_processes to $num_processes" +fi + +echo "Final number of processes to run: $num_processes" + +export JAX_ENABLE_ROCM_XDIST="$gpu_count" +export XLA_PYTHON_CLIENT_ALLOCATOR=platform +export XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1 --xla_gpu_enable_nccl_comm_splitting=false --xla_gpu_enable_command_buffer=" + +# Disable core dumps just in case +ulimit -c 0 + +# Keep deselected tests in one place for the abort wrapper. +ROCM_PYTEST_DESELECT_ARGS=( + --deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data + --deselect=tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices + --deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric +) + +# --max-runs: retry the entire pytest run up to N times on abort/crash. +# --max-worker-restart: restart crashed xdist workers up to N times. +# --maxfail: stop the run after N test failures. +rocm_test_cmd() { + local abort_flag="${1:-0}" + shift + if [[ "$abort_flag" == "1" ]]; then + pytest-abort-retry --max-runs 3 --clear-crash-log -- "$@" + else + "$@" + fi +} + +rocm_log_tail_on_failure() { + local logfile="$1" + local status="$2" + if [[ "$status" -ne 0 ]]; then + echo "Pytest failed (exit=$status). Showing last 200 lines of $logfile:" + tail -n 200 "$logfile" || true + else + echo "Pytest output saved to $logfile (uploaded as artifact)." + fi +} + +rocm_install_extra_requirements() { + if [[ -n "${GITHUB_WORKSPACE:-}" ]]; then + cd "$GITHUB_WORKSPACE" + fi + + # Install extra requirements. + "$JAXCI_PYTHON" -m uv pip install pytest-timeout pytest-html pytest-csv pytest-json-report pytest-abort +} + +rocm_install_extra_requirements + +echo "Running ROCm tests (with abort/retry wrapper)..." +mkdir -p logs_abort +logfile="logs_abort/jax_ToT_UT_abort.log" + +# pytest-abort output directories (must be set before running pytest). +export PYTEST_ABORT_LAST_RUNNING_DIR="logs_abort/last_running" +export PYTEST_ABORT_CRASHED_TESTS_LOG="logs_abort/crashed_tests.jsonl" +mkdir -p "$PYTEST_ABORT_LAST_RUNNING_DIR" + +set +e +rocm_test_cmd 1 "$JAXCI_PYTHON" -m pytest -n "$num_processes" --max-worker-restart=200 --tb=short --timeout=1200 --timeout-method=thread tests \ + "${ROCM_PYTEST_DESELECT_ARGS[@]}" \ + --json-report \ + --json-report-file=logs_abort/tests-report-abort.json \ + --csv=logs_abort/tests-report-abort.csv \ + --html=logs_abort/tests-report-abort.html \ + --self-contained-html \ + >"$logfile" 2>&1 +pytest_status=$? +set -e +rocm_log_tail_on_failure "$logfile" "$pytest_status" + +echo "Postprocessing reports with crashed tests..." +pytest-abort-postprocess \ + --crash-log "$PYTEST_ABORT_CRASHED_TESTS_LOG" \ + --json-report logs_abort/tests-report-abort.json \ + --html-report logs_abort/tests-report-abort.html \ + --csv-report logs_abort/tests-report-abort.csv \ + >>"$logfile" 2>&1 + +exit "$pytest_status" diff --git a/conftest.py b/conftest.py index fa0e6de94346..9578612e498e 100644 --- a/conftest.py +++ b/conftest.py @@ -16,6 +16,180 @@ import os import pytest +# Mosaic GPU checking based on test *file path* only (avoid test-name substrings). +_MOSAIC_GPU_PATH_NEEDLES = ( + f"{os.sep}tests{os.sep}mosaic{os.sep}", + f"{os.sep}tests{os.sep}pallas{os.sep}mgpu_", + f"{os.sep}tests{os.sep}pallas{os.sep}mosaic_gpu", + f"{os.sep}tests{os.sep}pallas{os.sep}mosaic", +) + +# Simple Mosaic GPU *usage* substring checks (avoid import-only signals). +_MOSAIC_GPU_SOURCE_NEEDLES = ( + "inline_mgpu", + "plgpu_mgpu.", + "mosaic_gpu_interpret", + "mosaic_gpu_backend", + "jax.experimental.mosaic.gpu", # runtime usage in body (not module import scan) + "jax.experimental.pallas.mosaic_gpu", +) + + +def _pallas_defaults_to_mosaic_gpu() -> bool: + """Returns True if Pallas GPU lowering defaults to Mosaic GPU.""" + try: + from jax._src.pallas import pallas_call as pallas_call_lib # pytype: disable=import-error + return bool(pallas_call_lib._PALLAS_USE_MOSAIC_GPU.value) # pylint: disable=protected-access + except Exception: + return False + + +def _running_on_rocm() -> bool: + """Best-effort ROCm detection. + + First tries to check rocm in jaxlib version, falls back to checking backend + platform_version so that it works for ROCm PJRT plugin installs where jaxlib's + version tag may not contain rocm. + """ + try: + import jaxlib.version as jaxlib_version # pytype: disable=import-error + version_str = getattr(jaxlib_version, "__version__", "") + except Exception: + version_str = "" + if "rocm" in version_str.lower(): + return True + try: + import jax # pytype: disable=import-error + from jax._src import xla_bridge # pytype: disable=import-error + backend = xla_bridge.get_backend() + pv = getattr(backend, "platform_version", "") or "" + return "rocm" in str(pv).lower() + except Exception: + return False + + +def _source_mentions_mosaic_gpu(src: str) -> bool: + """Returns True if the test file content has Mosaic GPU usage.""" + lowered = src.lower() + return any(n in lowered for n in _MOSAIC_GPU_SOURCE_NEEDLES) + + +def _looks_like_mosaic_gpu_path(path_str: str) -> bool: + """Returns True if the path is a Mosaic-GPU-only test file.""" + lowered = path_str.lower() + return any(n.lower() in lowered for n in _MOSAIC_GPU_PATH_NEEDLES) + + +def _class_mosaic_override(cls: type | None, cache: dict[object, object]) -> bool | None: + """Detects explicit class-level Mosaic enable/disable. + + Returns: + - True if the class forces Mosaic GPU (`_PALLAS_USE_MOSAIC_GPU(True)`). + - False if it forces Triton (`_PALLAS_USE_MOSAIC_GPU(False)`). + - None if no explicit override is found. + """ + if cls is None: + return None + cache_key = ("__mosaic_override__", cls) + if cache_key in cache: + return cache[cache_key] # type: ignore[return-value] + import inspect + try: + src = inspect.getsource(cls).lower() + except Exception: + cache[cache_key] = None + return None + if "_pallas_use_mosaic_gpu(true" in src: + cache[cache_key] = True + return True + if "_pallas_use_mosaic_gpu(false" in src: + cache[cache_key] = False + return False + cache[cache_key] = None + return None + + +def _is_mosaic_gpu_item( + item: pytest.Item, + cache: dict[object, bool], + *, + running_on_rocm: bool, + pallas_defaults_to_mosaic: bool, +) -> bool: + """Returns True if this test item uses (or would use) Mosaic GPU.""" + path_obj = getattr(item, "path", None) or getattr(item, "fspath", None) + path_str = str(path_obj) if path_obj is not None else "" + if _looks_like_mosaic_gpu_path(path_str): + return True + + import inspect + + obj = getattr(item, "obj", None) + if obj is None: + return False + if obj in cache: + return cache[obj] + try: + src = inspect.getsource(obj) + except Exception: + cache[obj] = False + return False + + lowered = src.lower() + # Direct Mosaic usage in the test function/method. + if _source_mentions_mosaic_gpu(lowered): + cache[obj] = True + return True + + # Respect explicit class-level override: if a test class forces Mosaic off, + # we should not skip it just because Pallas defaults to Mosaic elsewhere. + cls_override = _class_mosaic_override(getattr(item, "cls", None), cache) # type: ignore[arg-type] + if cls_override is False: + cache[obj] = False + return False + if cls_override is True: + cache[obj] = True + return True + + # Implicit Mosaic usage: on ROCm, `pallas_call` defaults to Mosaic GPU when + # `compiler_params` is not specified and Mosaic is the default backend. + if running_on_rocm and pallas_defaults_to_mosaic: + uses_pallas_call = ( + ".pallas_call" in lowered + or "pl.pallas_call" in lowered + or "pallas_call(" in lowered + ) + explicitly_selects_compiler = "compiler_params=" in lowered + if uses_pallas_call and not explicitly_selects_compiler: + cache[obj] = True + return True + + cache[obj] = False + return False + + +def pytest_collection_modifyitems( + config: pytest.Config, items: list[pytest.Item] +) -> None: + """Mark Mosaic GPU tests and skip them on ROCm.""" + running_on_rocm = _running_on_rocm() + pallas_defaults_to_mosaic = _pallas_defaults_to_mosaic_gpu() if running_on_rocm else False + cache: dict[object, bool] = {} + for item in items: + is_mosaic_gpu = _is_mosaic_gpu_item( + item, + cache, + running_on_rocm=running_on_rocm, + pallas_defaults_to_mosaic=pallas_defaults_to_mosaic, + ) + if not is_mosaic_gpu: + continue + item.add_marker(pytest.mark.mosaic_gpu) + if running_on_rocm: + item.add_marker(pytest.mark.skip( + reason="Mosaic GPU tests are not supported on ROCm" + )) + @pytest.fixture(autouse=True) def add_imports(doctest_namespace): @@ -72,3 +246,76 @@ def pytest_collection() -> None: os.environ.setdefault( "CUDA_VISIBLE_DEVICES", str(xdist_worker_number % num_cuda_devices) ) + + elif num_rocm_devices := os.environ.get("JAX_ENABLE_ROCM_XDIST", None): + num_rocm_devices = int(num_rocm_devices) + xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "") + if not xdist_worker_name.startswith("gw"): + return + xdist_worker_number = int(xdist_worker_name[len("gw") :]) + assigned = str(xdist_worker_number % num_rocm_devices) + + # If ROCR_VISIBLE_DEVICES is set, don't also set HIP_VISIBLE_DEVICES + # (double-filtering can produce HIP_ERROR_NoDevice). Respect the outer setting. + if os.environ.get("ROCR_VISIBLE_DEVICES"): + return + + # If present-but-empty, this can hide all GPUs. + if os.environ.get("HIP_VISIBLE_DEVICES", None) == "": + del os.environ["HIP_VISIBLE_DEVICES"] + + # HIP layer isolation (ROCm also accepts CUDA_VISIBLE_DEVICES, but we avoid it here). + os.environ["HIP_VISIBLE_DEVICES"] = assigned + +def pytest_configure(config: pytest.Config) -> None: + """Register custom pytest markers and print attached GPUs to xdist workers.""" + config.addinivalue_line( + "markers", + "mosaic_gpu: tests that use Mosaic GPU (skipped on ROCm)", + ) + + # Real pytest hook (runs early in main + each xdist worker). + xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "") or "main" + + # xdist master: print planned mapping (worker stdout is often hidden) + numproc = int(getattr(getattr(config, "option", None), "numprocesses", 0) or 0) + if xdist_worker_name == "main" and numproc > 0: + hip0 = (os.environ.get("HIP_VISIBLE_DEVICES") or "").strip() + cuda_x = (os.environ.get("JAX_ENABLE_CUDA_XDIST") or "").strip() + tpu_x = (os.environ.get("JAX_ENABLE_TPU_XDIST") or "").strip() + rocm_x = (os.environ.get("JAX_ENABLE_ROCM_XDIST") or "").strip() + if cuda_x: + try: + ndev = int(cuda_x) + except ValueError: + ndev = 0 + if ndev > 0: + mapping = ", ".join(f"gw{i}->CUDA_VISIBLE_DEVICES={i % ndev}" for i in range(numproc)) + print(f"[DeviceVisibility] xdist planned mapping: {mapping}", flush=True) + elif tpu_x: + mapping = ", ".join(f"gw{i}->TPU_VISIBLE_CHIPS={i}" for i in range(numproc)) + print(f"[DeviceVisibility] xdist planned mapping: {mapping}", flush=True) + elif rocm_x: + try: + ndev = int(rocm_x) + except ValueError: + ndev = 0 + if ndev > 0: + mapping = ", ".join(f"gw{i}->HIP_VISIBLE_DEVICES={i % ndev}" for i in range(numproc)) + print(f"[DeviceVisibility] xdist planned mapping: {mapping}", flush=True) + elif hip0: + print(f"[DeviceVisibility] master HIP_VISIBLE_DEVICES={hip0}", flush=True) + + if os.environ.get("JAX_ENABLE_TPU_XDIST", None): + if xdist_worker_name.startswith("gw"): + xdist_worker_number = int(xdist_worker_name[len("gw") :]) + os.environ.setdefault("TPU_VISIBLE_CHIPS", str(xdist_worker_number)) + os.environ.setdefault("ALLOW_MULTIPLE_LIBTPU_LOAD", "true") + + elif num_cuda_devices := os.environ.get("JAX_ENABLE_CUDA_XDIST", None): + if xdist_worker_name.startswith("gw"): + num_cuda_devices = int(num_cuda_devices) + xdist_worker_number = int(xdist_worker_name[len("gw") :]) + os.environ.setdefault( + "CUDA_VISIBLE_DEVICES", str(xdist_worker_number % num_cuda_devices) + ) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 42d62340bd92..e25d4bbdf792 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -564,9 +564,8 @@ def make_ir_context() -> ir.Context: # multi threaded execution aborts the process if we try to register a new # dialect after this point. The dialect registry in a context is not thread # safe, and a fatal error is much better than a data race. - # jax_mlir_ext.enter_multi_threaded_execution(context) - # TODO(phawkins): clean up users who add their own dialects to JAX's contexts - # and enable this. + # if jaxlib_version >= (0, 8): + # jax_mlir_ext.enter_multi_threaded_execution(context) return context diff --git a/jax/_src/pallas/triton/primitives.py b/jax/_src/pallas/triton/primitives.py index 8e763c3d8e6a..25423ebba471 100644 --- a/jax/_src/pallas/triton/primitives.py +++ b/jax/_src/pallas/triton/primitives.py @@ -48,6 +48,11 @@ def approx_tanh(x: jax.Array) -> jax.Array: elif x.dtype == jnp.float32: asm = "tanh.approx.f32 $0, $1;" constraint = "f" + elif x.dtype == jnp.float64: + # f64 tanh.approx is only supported on ROCm (uses __ocml_tanh_f64) + # CUDA does not have a PTX instruction for f64 approximate tanh + asm = "tanh.approx.f64 $0, $1;" + constraint = "d" else: raise TypeError(f"approx_tanh does not accept {x.dtype} arrays") @@ -119,6 +124,13 @@ def _elementwise_inline_asm_lowering( result_shape_dtypes, ): del result_shape_dtypes # Unused. + + # For ROCm, PTX inline assembly is not supported. For tanh.approx, we use + # Triton's __triton_hip_fast_tanhf (fast exp-based formula) for f32, and + # OCML's __ocml_tanh_f64 for f64. See: https://github.com/triton-lang/triton/pull/7780 + if ctx.context.platform == "rocm" and "tanh.approx" in asm: + return _approx_tanh_rocm_lowering(ctx, *args) + return tt_dialect.ElementwiseInlineAsmOp( [*map(mlir.aval_to_ir_type, ctx.avals_out)], asm, @@ -129,6 +141,86 @@ def _elementwise_inline_asm_lowering( ).result +def _approx_tanh_rocm_lowering( + ctx: lowering.LoweringRuleContext, + *args, +): + """Lower approx_tanh for ROCm. + + AMD CDNA3 (MI300X/gfx942) does not have a hardware tanh instruction. + + For f32 (and f16/bf16 via casting): We use Triton's __triton_hip_fast_tanhf + which implements a fast exp-based formula: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) + See: https://github.com/triton-lang/triton/pull/7780 + + For f64: We use OCML's __ocml_tanh_f64 (AMD's Open Compute Math Library) + since fast_tanhf only supports f32. + """ + from jax._src.lib.mlir import ir + from jax._src.lib.mlir.dialects import arith as arith_dialect + + [arg] = args + [out_aval] = ctx.avals_out + in_dtype = ctx.avals_in[0].dtype + + # Helper to get IR type for a dtype + def dtype_to_ir_type(dtype): + dtype = jnp.dtype(dtype) + return mlir.dtype_to_ir_type(dtype) + + # f64: use __ocml_tanh_f64 (fast_tanhf only supports f32) + if in_dtype == jnp.float64: + result_type = mlir.aval_to_ir_type(out_aval) + result = tt_dialect.extern_elementwise( + result_type, + list(args), + libname="", + libpath="", + symbol="__ocml_tanh_f64", + pure=True, + ) + return [result] + + # fast_tanhf only supports f32. For f16/bf16, cast to f32, compute, cast back. + needs_cast = in_dtype in (jnp.float16, jnp.bfloat16) + + if needs_cast: + # Cast input to f32 (extend) + f32_type = dtype_to_ir_type(jnp.float32) + if out_aval.shape: + f32_result_type = ir.RankedTensorType.get(out_aval.shape, f32_type) + else: + f32_result_type = f32_type + arg_f32 = arith_dialect.extf(f32_result_type, arg) + + # Call __triton_hip_fast_tanhf (fast exp-based implementation) + tanh_result = tt_dialect.extern_elementwise( + f32_result_type, + [arg_f32], + libname="libdevice", + libpath="", + symbol="__triton_hip_fast_tanhf", + pure=True, + ) + + # Cast result back to original dtype (truncate) + out_type = mlir.aval_to_ir_type(out_aval) + result = arith_dialect.truncf(out_type, tanh_result) + else: + # f32: call __triton_hip_fast_tanhf directly + result_type = mlir.aval_to_ir_type(out_aval) + result = tt_dialect.extern_elementwise( + result_type, + list(args), + libname="libdevice", + libpath="", + symbol="__triton_hip_fast_tanhf", + pure=True, + ) + + return [result] + + def debug_barrier() -> None: """Synchronizes all kernel executions in the grid.""" return debug_barrier_p.bind() diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index df5e1433cd94..b9ed2fa67f05 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -25,6 +25,7 @@ import logging import math import os +from pathlib import Path import platform import re import sys @@ -392,6 +393,15 @@ def supported_dtypes(): def is_device_rocm(): return 'rocm' in xla_bridge.get_backend().platform_version +def get_rocm_version(): + rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm") + version_path = Path(rocm_path) / ".info" / "version" + if not version_path.exists(): + raise FileNotFoundError(f"Expected ROCm version file at {version_path}") + version_str = version_path.read_text().strip() + major, minor, *_ = version_str.split(".") + return int(major), int(minor) + def is_device_cuda(): return 'cuda' in xla_bridge.get_backend().platform_version @@ -613,7 +623,14 @@ def skip_on_devices(*disabled_devices, skip_reason=None): skip_reason: Optional custom skip message when test is skipped. """ if skip_reason is None: - skip_reason = "Skipped on devices with tags: " + ", ".join(disabled_devices) + skip_messages = { + ("gpu",): "Skipped on all GPUs.", + ("cpu",): "Skipped on CPU.", + ("tpu",): "Skipped on TPU.", + ("cuda",): "Skipped on CUDA GPUs.", + ("rocm",): "Skipped on ROCm GPUs.", + } + skip_reason = skip_messages.get(disabled_devices) return _device_filter(lambda: not test_device_matches(disabled_devices), skip_reason) def run_on_devices(*enabled_devices, skip_reason=None): @@ -624,9 +641,14 @@ def run_on_devices(*enabled_devices, skip_reason=None): skip_reason: Optional custom skip message when test is skipped. """ if skip_reason is None: - skip_reason = ( - "Skipped unless running on devices with tags: " + ", ".join(enabled_devices) - ) + device_specific_skip_reasons = { + ("cpu",): "Skipped: CPU-only test.", + ("tpu",): "Skipped: TPU-only test.", + ("gpu",): "Skipped: GPU-only test.", + ("rocm",): "Skipped: ROCm-only test.", + ("cuda",): "Skipped: CUDA-only test.", + } + skip_reason = device_specific_skip_reasons.get(enabled_devices) return _device_filter(lambda: test_device_matches(enabled_devices), skip_reason) def device_supports_buffer_donation(): diff --git a/jaxlib/BUILD b/jaxlib/BUILD index d9a3a619965c..639803e667d4 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -39,7 +39,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//jax:internal"], + default_visibility = ["//visibility:public"], ) package_group( diff --git a/jaxlib/gpu/solver_interface.cc b/jaxlib/gpu/solver_interface.cc index e10c08d54e9f..8866c7bea2fe 100644 --- a/jaxlib/gpu/solver_interface.cc +++ b/jaxlib/gpu/solver_interface.cc @@ -62,8 +62,8 @@ JAX_GPU_DEFINE_GETRF(gpuDoubleComplex, gpusolverDnZgetrf); JAX_GPU_DEFINE_GETRF_BATCHED(float, gpublasSgetrfBatched); JAX_GPU_DEFINE_GETRF_BATCHED(double, gpublasDgetrfBatched); -JAX_GPU_DEFINE_GETRF_BATCHED(gpublasComplex, gpublasCgetrfBatched); -JAX_GPU_DEFINE_GETRF_BATCHED(gpublasDoubleComplex, gpublasZgetrfBatched); +JAX_GPU_DEFINE_GETRF_BATCHED(gpuComplex, gpublasCgetrfBatched); +JAX_GPU_DEFINE_GETRF_BATCHED(gpuDoubleComplex, gpublasZgetrfBatched); #undef JAX_GPU_DEFINE_GETRF_BATCHED // QR decomposition: geqrf @@ -101,8 +101,8 @@ JAX_GPU_DEFINE_GEQRF(gpuDoubleComplex, gpusolverDnZgeqrf); JAX_GPU_DEFINE_GEQRF_BATCHED(float, gpublasSgeqrfBatched); JAX_GPU_DEFINE_GEQRF_BATCHED(double, gpublasDgeqrfBatched); -JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasComplex, gpublasCgeqrfBatched); -JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasDoubleComplex, gpublasZgeqrfBatched); +JAX_GPU_DEFINE_GEQRF_BATCHED(gpuComplex, gpublasCgeqrfBatched); +JAX_GPU_DEFINE_GEQRF_BATCHED(gpuDoubleComplex, gpublasZgeqrfBatched); #undef JAX_GPU_DEFINE_GEQRF_BATCHED // Householder transformations: orgqr @@ -272,8 +272,8 @@ JAX_GPU_DEFINE_SYEVD(gpuDoubleComplex, gpusolverDnZheevd); JAX_GPU_DEFINE_SYRK(float, gpublasSsyrk); JAX_GPU_DEFINE_SYRK(double, gpublasDsyrk); -JAX_GPU_DEFINE_SYRK(gpublasComplex, gpublasCsyrk); -JAX_GPU_DEFINE_SYRK(gpublasDoubleComplex, gpublasZsyrk); +JAX_GPU_DEFINE_SYRK(gpuComplex, gpublasCsyrk); +JAX_GPU_DEFINE_SYRK(gpuDoubleComplex, gpublasZsyrk); #undef JAX_GPU_DEFINE_SYRK // Singular Value Decomposition: gesvd diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 4e6c4ca9a7d4..be985a8a2306 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -446,6 +446,8 @@ inline constexpr uint32_t kNumThreadsPerWarp = 32; #elif defined(JAX_GPU_HIP) +#define HIPBLAS_V2 1 + // IWYU pragma: begin_exports #include "rocm/include/hip/hip_cooperative_groups.h" #include "rocm/include/hip/hip_runtime_api.h" @@ -466,17 +468,11 @@ inline constexpr uint32_t kNumThreadsPerWarp = 32; // MIOpen lib. Remove when MIOpen support is complete. #define MIOPEN_STATUS_SUCCESS 0 -typedef hipFloatComplex gpuComplex; +typedef hipComplex gpuComplex; typedef hipDoubleComplex gpuDoubleComplex; -#if TF_ROCM_VERSION >= 70000 -typedef hipFloatComplex gpublasComplex; +typedef hipComplex gpublasComplex; typedef hipDoubleComplex gpublasDoubleComplex; -#else -typedef hipblasComplex gpublasComplex; -typedef hipblasDoubleComplex gpublasDoubleComplex; -#endif // TF_ROCM_VERSION >= 70000 - typedef struct hipsolverHandle_* gpusolverDnHandle_t; typedef hipblasFillMode_t gpublasFillMode_t; typedef hipsolverFillMode_t gpusolverFillMode_t; @@ -533,6 +529,7 @@ inline hipblasStatus_t gpublasCreate(gpublasHandle_t* handle) { return hipblasCreate(reinterpret_cast(handle)); } } // namespace jax::hip + #define gpublasCreate ::jax::hip::gpublasCreate #define gpublasSetStream hipblasSetStream #define gpublasSgeqrfBatched hipblasSgeqrfBatched @@ -589,6 +586,7 @@ inline hipsolverStatus_t gpusolverDnCreate(gpusolverDnHandle_t* handle) { return hipsolverCreate(reinterpret_cast(handle)); } } // namespace jax::hip + #define gpusolverDnCreate ::jax::hip::gpusolverDnCreate #define gpusolverDnSetStream hipsolverSetStream #define gpusolverDnCreateSyevjInfo hipsolverCreateSyevjInfo @@ -758,10 +756,10 @@ inline hipsparseStatus_t gpusparseCreate(gpusparseHandle_t* handle) { #define GPUSPARSE_INDEX_32I HIPSPARSE_INDEX_32I #define GPUSPARSE_INDEX_64I HIPSPARSE_INDEX_64I #define GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT -#define GPUSPARSE_SPMV_COO_ALG HIPSPARSE_MV_ALG_DEFAULT -#define GPUSPARSE_SPMV_CSR_ALG HIPSPARSE_MV_ALG_DEFAULT -#define GPUSPARSE_SPMM_COO_ALG HIPSPARSE_SPMM_ALG_DEFAULT -#define GPUSPARSE_SPMM_CSR_ALG HIPSPARSE_SPMM_ALG_DEFAULT +#define GPUSPARSE_SPMV_COO_ALG HIPSPARSE_COOMV_ALG +#define GPUSPARSE_SPMV_CSR_ALG HIPSPARSE_CSRMV_ALG1 +#define GPUSPARSE_SPMM_COO_ALG HIPSPARSE_SPMM_COO_ALG1 +#define GPUSPARSE_SPMM_CSR_ALG HIPSPARSE_SPMM_CSR_ALG1 #define GPUSPARSE_INDEX_BASE_ZERO HIPSPARSE_INDEX_BASE_ZERO #define GPUSPARSE_OPERATION_NON_TRANSPOSE HIPSPARSE_OPERATION_NON_TRANSPOSE #define GPUSPARSE_OPERATION_TRANSPOSE HIPSPARSE_OPERATION_TRANSPOSE @@ -774,7 +772,7 @@ inline hipsparseStatus_t gpusparseCreate(gpusparseHandle_t* handle) { #define GPU_STREAM_NON_BLOCKING hipStreamNonBlocking #define gpuMalloc hipMalloc -#define gpuGetLastError hipGetLastError +#define gpuGetLastError hipExtGetLastError #define gpuGetErrorString hipGetErrorString #define gpuMemcpyAsync hipMemcpyAsync #define gpuMemcpyDeviceToDevice hipMemcpyDeviceToDevice diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 2e2757f2ecc4..7989c6c3f4c7 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -20,7 +20,6 @@ load("@jax_wheel//:wheel.bzl", "WHEEL_VERSION") load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "WHEEL_VERSION_SUFFIX") load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") -load("@nvidia_wheel_versions//:versions.bzl", "NVIDIA_WHEEL_VERSIONS") load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION", "HERMETIC_PYTHON_VERSION_KIND") load("@rocm_external_test_deps//:external_deps.bzl", "EXTERNAL_DEPS") load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library") @@ -461,7 +460,6 @@ def _jax_wheel_impl(ctx): if ctx.attr.platform_version == "": fail("platform_version must be set to a valid cuda version for cuda wheels") args.add("--platform_version", ctx.attr.platform_version) # required for gpu wheels - args.add("--nvidia_wheel_versions_data", NVIDIA_WHEEL_VERSIONS) # required for gpu wheels if ctx.attr.enable_rocm: args.add("--enable-rocm", "True") if ctx.attr.platform_version == "": diff --git a/jaxlib/mlir/_mlir_libs/jax_mlir_ext.cc b/jaxlib/mlir/_mlir_libs/jax_mlir_ext.cc index f1bf60bb0517..65ee1673ca7c 100644 --- a/jaxlib/mlir/_mlir_libs/jax_mlir_ext.cc +++ b/jaxlib/mlir/_mlir_libs/jax_mlir_ext.cc @@ -204,7 +204,6 @@ NB_MODULE(_jax_mlir_ext, m) { unwrap(registry)->insert(); unwrap(registry)->insert(); unwrap(registry)->insert(); - // For Mosaic GPU REGISTER_DIALECT(cf); REGISTER_DIALECT(gpu); diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 72f2ec6dae80..f266baeb4683 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -27,7 +27,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//visibility:public"], ) cc_library( @@ -404,6 +404,10 @@ nanobind_extension( "@nanobind", "@xla//xla/ffi/api:ffi", ], + linkopts = [ + "-L/opt/rocm/lib", + "-lamdhip64", + ], ) cc_library( @@ -531,6 +535,7 @@ nanobind_extension( srcs = ["rocm_plugin_extension.cc"], module_name = "rocm_plugin_extension", deps = [ + ":hip_gpu_kernel_helpers", ":py_client_gpu", "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:gpu_plugin_extension", diff --git a/jaxlib/rocm/rocm_plugin_extension.cc b/jaxlib/rocm/rocm_plugin_extension.cc index 05be1e81c858..e28e4927b81f 100644 --- a/jaxlib/rocm/rocm_plugin_extension.cc +++ b/jaxlib/rocm/rocm_plugin_extension.cc @@ -23,6 +23,7 @@ limitations under the License. #include "jaxlib/gpu/gpu_plugin_extension.h" #include "jaxlib/gpu/py_client_gpu.h" #include "jaxlib/kernel_nanobind_helpers.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" namespace nb = nanobind; @@ -96,6 +97,13 @@ nb::dict FfiHandlers() { return dict; } +int ROCmDeviceCount() { + int device_count = -1; + JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipInit(0))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipGetDeviceCount(&device_count))); + return device_count; +} + } // namespace NB_MODULE(rocm_plugin_extension, m) { @@ -122,5 +130,6 @@ NB_MODULE(rocm_plugin_extension, m) { return device_ordinal; }, nb::arg("data_value")); + m.def("get_device_count", &ROCmDeviceCount); } } // namespace jax diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index d033a490f547..b14e9171f952 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -312,7 +312,7 @@ def testCudaArrayInterfaceOnNonCudaFails(self): self.assertFalse(hasattr(x, "__cuda_array_interface__")) with self.assertRaisesRegex( AttributeError, - "__cuda_array_interface__ is only defined for .*GPU buffers.", + "__cuda_array_interface__ is only defined for GPU buffers.", ): _ = x.__cuda_array_interface__ diff --git a/tests/array_test.py b/tests/array_test.py index 61c30b9ed065..970b93050de4 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1591,7 +1591,7 @@ class RngShardingTest(jtu.JaxTestCase): # tests that the PRNGs are automatically sharded as expected @parameterized.named_parameters(("3", 3), ("4", 4), ("5", 5)) - @jtu.skip_on_devices("gpu") + @jtu.skip_on_devices("cuda") def test_random_bits_is_pure_map_1d(self, num_devices): @jax.jit def f(x): @@ -1625,7 +1625,7 @@ def f(x): "mesh_shape": mesh_shape, "pspec": pspec} for mesh_shape in [(3, 2), (4, 2), (2, 3)] for pspec in [P('x', None), P(None, 'y'), P('x', 'y')]) - @jtu.skip_on_devices("gpu") + @jtu.skip_on_devices("cuda") def test_random_bits_is_pure_map_2d(self, mesh_shape, pspec): @jax.jit def f(x): diff --git a/tests/experimental_rnn_test.py b/tests/experimental_rnn_test.py index 4bb611dcd842..61e79c00a09e 100644 --- a/tests/experimental_rnn_test.py +++ b/tests/experimental_rnn_test.py @@ -179,7 +179,7 @@ def f(weights, x, h_0, c_0): y_padded = y_ref[i, seq_lengths[i]:] np.testing.assert_allclose(y_padded, jnp.zeros_like(y_padded)) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("gpu") def test_struct_encoding_determinism(self): def f(k1, k2, k3, k4): batch_size = 1 @@ -213,8 +213,15 @@ def f(k1, k2, k3, k4): k = jax.random.split(jax.random.PRNGKey(1), 4) stablehlo = jax.jit(f).lower(*k).as_text("stablehlo") - self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"', - stablehlo) + # Platform-specific binary encodings for RnnDescriptor + cuda_encoding = '"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"' + rocm_encoding = '"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\008\\00\\00\\00\\00\\00\\00\\00\\1C\\00\\00\\00\\00\\00\\00\\00"' + + # Check that one of the expected encodings is present + if jtu.test_device_matches(["cuda"]): + self.assertIn(cuda_encoding, stablehlo) + elif jtu.test_device_matches(["rocm"]): + self.assertIn(rocm_encoding, stablehlo) # Note: Other LSTM tests that use `bidirectional=True` on ROCm are skipped # because of current numerical issues (as of ROCm 7.1.1). However, this diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 57fdfc4dda88..de260ec4ed93 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -759,6 +759,8 @@ def testSort(self, shape, dimension, arity, bdims, is_stable): # TODO Collapse # TODO Scatter + # b/183233858: variadic reduce-window not implemented on XLA:CUDA + @jtu.skip_on_devices("cuda") def test_variadic_reduce_window(self): # https://github.com/jax-ml/jax/discussions/9818 and # https://github.com/jax-ml/jax/issues/9837 diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 63c8a28bb76b..5b38275947ff 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -570,7 +570,12 @@ def testEighZeroDiagonal(self): np.linalg.norm(np.matmul(a, v) - w * v), 2.5 * eps * np.linalg.norm(a) ) + def testEighTinyNorm(self): + if jtu.is_device_rocm(): + # numerical errors seen as of ROCm 7.2 due to hipSolver issue + # TODO: re-enable the test once the hipSolver issue is fixed + self.skipTest("testEighNorm not supported on ROCm due to hipSOLVER issue") rng = jtu.rand_default(self.rng()) a = rng((300, 300), dtype=np.float32) eps = jnp.finfo(a.dtype).eps @@ -2371,6 +2376,7 @@ def testSelect(self, dtype): @jtu.sample_product(shape=[(3,), (3, 4), (3, 4, 5)], dtype=float_types + complex_types) + @jtu.skip_on_devices("rocm") # Numerical errors on ROCm def test_tridiagonal_solve(self, shape, dtype): if dtype not in float_types and jtu.test_device_matches(["gpu"]): self.skipTest("Data type not supported on GPU") diff --git a/tests/pallas/gpu_paged_attention_test.py b/tests/pallas/gpu_paged_attention_test.py index 1b778c787a6d..6a1d8de22a78 100644 --- a/tests/pallas/gpu_paged_attention_test.py +++ b/tests/pallas/gpu_paged_attention_test.py @@ -112,6 +112,64 @@ class PagedAttentionKernelTest(PallasBaseTest): def setUp(self): super().setUp() + def _estimate_shared_memory_bytes(self, block_h, pages_per_compute_block, + page_size, head_dim, dtype): + """Estimate shared memory usage for paged attention kernel.""" + dtype_size = jnp.dtype(dtype).itemsize + # Approximate calculation based on kernel's memory usage + # Q block: block_h * head_dim + # K/V blocks: pages_per_compute_block * page_size * head_dim + # Plus accumulators and intermediate values + block_k = pages_per_compute_block * page_size + estimated = dtype_size * ( + block_h * head_dim + # Q + 2 * block_k * head_dim + # K and V + block_h * block_k + # logits/attention weights + block_h * 8 # accumulators (m, l, etc.) in float32 + ) + return estimated + + def _adjust_params_for_shared_memory(self, block_h, pages_per_compute_block, + page_size, head_dim, dtype): + """Adjust parameters to fit within device shared memory limits. + + Uses XLA's DeviceDescription.shared_memory_per_block_optin() to query + the actual device capability rather than hardcoding values. + """ + try: + device = jax.local_devices()[0] + # Query XLA DeviceDescription for max shared memory per block + # This is exposed from stream_executor::DeviceDescription::shared_memory_per_block_optin() + max_smem = device.shared_memory_per_block_optin + except (AttributeError, IndexError): + # Fallback if XLA doesn't expose shared_memory_per_block_optin (older versions) + # or if no devices are available. Use conservative 48KB (safe for most GPUs). + max_smem = 48 * 1024 + + estimated = self._estimate_shared_memory_bytes( + block_h, pages_per_compute_block, page_size, head_dim, dtype) + + # If within limits, no adjustment needed + if estimated <= max_smem: + return block_h, pages_per_compute_block, page_size + + # Try to reduce parameters to fit + while estimated > max_smem: + if pages_per_compute_block > 2: + pages_per_compute_block = pages_per_compute_block // 2 + elif page_size > 8: + page_size = page_size // 2 + elif block_h > 8: + block_h = block_h // 2 + else: + # Can't reduce further, will need to skip + return None, None, None + + estimated = self._estimate_shared_memory_bytes( + block_h, pages_per_compute_block, page_size, head_dim, dtype) + + return block_h, pages_per_compute_block, page_size + @jtu.sample_product( dtype=(jnp.float16,), page_size=(8, 16, 32), @@ -201,6 +259,17 @@ def test_quantized_paged_attention( if (quant_dtype == jnp.float8_e4m3fn and not jtu.is_cuda_compute_capability_at_least("8.9")): self.skipTest("Skipping since float8_e4m3fn is not supported on < sm89") + + # Check and adjust parameters if needed to fit device limits for ROCm + if jtu.is_device_rocm(): + adjusted = self._adjust_params_for_shared_memory( + block_h, pages_per_compute_block, page_size, head_dim, dtype) + + if adjusted == (None, None, None): + self.skipTest("Cannot adjust parameters to fit ROCm device shared memory limits") + + block_h, pages_per_compute_block, page_size = adjusted + max_kv_len = 2048 seq_lens = np.asarray([3, 256, 513, 1023, 2048], dtype=jnp.int32) q, k_pages, v_pages, block_tables = _generate_qkv( @@ -218,7 +287,7 @@ def test_quantized_paged_attention( k_, k_scales = (_quantize(k_pages, quant_dtype) if quantize_k else (k_pages, None)) - v_, v_scales = (_quantize(k_pages, quant_dtype) + v_, v_scales = (_quantize(v_pages, quant_dtype) if quantize_v else (v_pages, None)) o = paged_attention.paged_attention( diff --git a/tests/pallas/gpu_pallas_distributed_test.py b/tests/pallas/gpu_pallas_distributed_test.py index 610965cb8429..dce9fc329187 100644 --- a/tests/pallas/gpu_pallas_distributed_test.py +++ b/tests/pallas/gpu_pallas_distributed_test.py @@ -51,20 +51,22 @@ def setUp(self): if jtu.test_device_matches(["rocm"]): self.skipTest("Mosaic not supported on ROCm currently.") + # Check mosaic support first (before GPU capability check) + if not mgpu.supports_cross_device_collectives(): + if jtu.test_device_matches(["rocm"]): + self.skipTest("Mosaic not supported on ROCm currently.") + else: + self.skipTest("NVSHMEM library unavailable.") if (not jtu.test_device_matches(["cuda"]) or not jtu.is_cuda_compute_capability_at_least("9.0")): self.skipTest("Only works on GPU with capability >= sm90") - if not mgpu.supports_cross_device_collectives(): - self.skipTest( - "Skip test since cross-device collectives are not supported" - " (either NVSHMEM is not available in multi-process mode, or mixed" - " mode is used).") + if jax.process_count() == 1: + self.skipTest("Test requires multiple processes.") if os.environ.get("XLA_PYTHON_CLIENT_ALLOCATOR", "") == "platform": self.skipTest("NVSHMEM doesn't work with the platform allocator.") super().setUp() - class PallasCallRemoteDMATest(TestCase): def test_remote_dma_basic(self): diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index e8628dda490c..b5a3de52dd49 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -74,6 +74,34 @@ def is_power_of_two(n: int) -> bool: return (n > 0) and (n & (n - 1) == 0) +def get_rocm_shared_memory_limit() -> int: + """Get the shared memory (LDS) limit in bytes for ROCm devices. + + Queries rocminfo to get the GROUP segment size dynamically. + Returns 64KB as default if rocminfo fails (MI100/MI200/MI300 all have 64KB LDS). + """ + try: + result = subprocess.run( + ['rocminfo'], capture_output=True, text=True, timeout=10 + ) + if result.returncode != 0: + return 64 * 1024 # Default if rocminfo fails + lines = result.stdout.split('\n') + for i, line in enumerate(lines): + if 'Segment:' in line and 'GROUP' in line: + if i + 1 < len(lines): + size_line = lines[i + 1] + # Match "Size: () KB" with case-insensitive KB check + match = re.search(r'Size:\s+(\d+)\s*\([^)]+\)\s*KB', size_line, re.IGNORECASE) + if match: + size_kb = int(match.group(1)) + return size_kb * 1024 # Convert KB to bytes + except Exception: + pass + # Default for AMD GPUs (MI100/MI200/MI300 all have 64KB LDS) + return 64 * 1024 + + def smem_on_tpu(): if jtu.test_device_matches(["tpu"]): return pltpu.SMEM @@ -1003,6 +1031,13 @@ def test_is_finite(self, dtype): # The original test worked only on fp32@TPU, have no way to test CUDA self.skipTest("Not tested on CUDA, todo for the respective team") + if jtu.test_device_matches(["cuda"]): + self.skipTest("Not tested on CUDA") # set this b/c this how the test was + # originally configured. Have no way to test cuda. + + if jtu.is_device_rocm(): + self.skipTest("is_finite not in Triton lowering for jax 0.8.0") + size = len(self.IS_FINITE_TEST_VALUES) @functools.partial( @@ -1053,6 +1088,9 @@ def test_is_finite_scalar(self, dtype): # The original test worked only on fp32@TPU, have no way to test CUDA self.skipTest("Not tested on CUDA, todo for the respective team") + if jtu.is_device_rocm(): + self.skipTest("is_finite not in Triton lowering for jax 0.8.0") + size = len(self.IS_FINITE_TEST_VALUES) @functools.partial( @@ -1872,6 +1910,47 @@ def kernel(o_ref): np.testing.assert_allclose(f(), kernel()) + @parameterized.parameters("float16", "bfloat16", "float32", "float64") + def test_approx_tanh(self, dtype): + self.skip_if_mosaic_gpu() + + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented on TPU") + + if self.INTERPRET: + self.skipTest("approx_tanh is not supported in interpret mode") + + if (dtype == "bfloat16" and + jtu.test_device_matches(["cuda"]) and + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90") + + if dtype == "float64": + if jtu.test_device_matches(["cuda"]): + self.skipTest("f64 approx_tanh is only supported on ROCm") + + # Enable x64 for f64 test if not already enabled, restore after test + original_x64 = jax.config.x64_enabled + if dtype == "float64" and not original_x64: + jax.config.update("jax_enable_x64", True) + self.addCleanup(lambda: jax.config.update("jax_enable_x64", False)) + + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), + ) + def kernel(x_ref, o_ref): + o_ref[...] = plgpu_triton.approx_tanh(x_ref[...]) + + x = jnp.asarray([-1, 0.42, 0.24, 1]).astype(dtype) + # We upcast to float32 because NumPy <2.0 does not handle custom dtypes + # properly. See https://github.com/jax-ml/jax/issues/11014. + np.testing.assert_allclose( + kernel(x).astype(jnp.float32), + jnp.tanh(x).astype(jnp.float32), + atol=5e-3, + rtol=5e-3, + ) + @parameterized.parameters( ((2, 4), (8,)), ((2, 4), (8, 1)), @@ -2085,12 +2164,22 @@ def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y): if jtu.test_device_matches(["gpu"]): if dtype == jnp.bfloat16: self.skipTest("bfloat16 type are not supported on GPU") - if ( - math.prod(lhs_shape) + math.prod(rhs_shape) + math.prod(out_shape) - > (256 * 256) * 2 - ): - self.skipTest("Shared memory size limit exceeded") - if (jax.local_devices()[0].shared_memory_per_block_optin == 99 * 1024 and + # Check shared memory limit: Triton loads lhs + rhs into shared memory + if jtu.is_device_rocm(): + # ROCm: use correct formula with dynamic limit from rocminfo + dtype_size = jnp.dtype(dtype).itemsize + shared_mem_bytes = (math.prod(lhs_shape) + math.prod(rhs_shape)) * dtype_size + shared_mem_limit = get_rocm_shared_memory_limit() + if shared_mem_bytes > shared_mem_limit: + self.skipTest("Shared memory size limit exceeded") + else: + # NVIDIA: keep original check + if ( + math.prod(lhs_shape) + math.prod(rhs_shape) + math.prod(out_shape) + > (256 * 256) * 2 + ): + self.skipTest("Shared memory size limit exceeded") + if (jax.local_devices()[0].device_kind == "NVIDIA L4" and dtype == jnp.float32 and lhs_and_rhs_shape in [ ((128, 16), (128, 256)), diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index 1842da7c1b25..44314e9034fe 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -963,7 +963,7 @@ def testMultivariateNormalCovariance(self): check_dtypes=False) @jtu.sample_product(method=['cholesky', 'eigh', 'svd']) - @jtu.skip_on_devices('cuda', 'tpu') # Some NaNs on accelerators. + @jtu.skip_on_devices('cuda', 'tpu') # Some NaNs on accelerators. ROCm supported def testMultivariateNormalSingularCovariance(self, method): # Singular covariance matrix https://github.com/jax-ml/jax/discussions/13293 mu = jnp.zeros((2,)) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 58adf3a42cf1..30240ff69c50 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -139,8 +139,6 @@ def test_csr_fromdense_ad(self, shape, dtype): @jax.default_matmul_precision("float32") def test_csr_matmul_ad(self, shape, dtype, bshape): if jtu.is_device_rocm(): - # hipSPARSE segfault observed as of ROCm 7.2. - # TODO(ROCm): Re-enable once hipSPARSE issue is fixed. self.skipTest("test_csr_matmul_ad not supported on ROCm due to hipSPARSE issue") csr_matmul = sparse_csr._csr_matvec if len(bshape) == 1 else sparse_csr._csr_matmat tol = {np.float32: 2E-5, np.float64: 1E-12, np.complex64: 1E-5, @@ -221,8 +219,6 @@ def test_csr_fromdense(self, shape, dtype): ) def test_csr_matvec(self, shape, dtype, transpose): if jtu.is_device_rocm(): - # hipSPARSE segfault observed as of ROCm 7.2. - # TODO(ROCm): Re-enable once hipSPARSE issue is fixed. self.skipTest("test_csr_matvec not supported on ROCm due to hipSPARSE issue") op = lambda M: M.T if transpose else M @@ -594,8 +590,6 @@ def test_coo_spmm(self, shape, dtype, transpose): @jtu.run_on_devices("gpu") def test_csr_spmv(self, shape, dtype, transpose): if jtu.is_device_rocm(): - # hipSPARSE segfault observed as of ROCm 7.2. - # TODO(ROCm): Re-enable once hipSPARSE issue is fixed. self.skipTest("test_csr_spmv not supported on ROCm due to hipSPARSE issue") tol = {np.float32: 2E-5, np.float64: 2E-14} @@ -1049,8 +1043,6 @@ def test_transpose(self, shape, dtype, Obj): @jax.default_matmul_precision("float32") def test_matmul(self, shape, dtype, Obj, bshape): if jtu.is_device_rocm(): - # hipSPARSE segfault observed as of ROCm 7.2. - # TODO(ROCm): Re-enable once hipSPARSE issue is fixed. self.skipTest("test_matmul not supported on ROCm due to hipSPARSE issue") rng = sptu.rand_sparse(self.rng(), post=jnp.array) rng_b = jtu.rand_default(self.rng())