Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build-wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ on:
description: 'JAX git ref (branch, tag, or commit) to checkout for jaxlib build'
required: false
type: string
default: 'rocm-jaxlib-v0.8.2'
default: 'rocm-jaxlib-v0.9.1'
builder-image:
required: false
type: string
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
# manage tests
clean: false
repository: rocm/jax
ref: rocm-jaxlib-v0.8.2
ref: rocm-jaxlib-v0.9.1
path: jax
- name: Apply patches to rocm/jax test repo
run: |
Expand Down Expand Up @@ -105,5 +105,5 @@ jobs:
run: |
python3 build/ci_build test \
"ghcr.io/rocm/jax-ubu24.rocm${ROCM_VERSION//.}:${GITHUB_SHA}" \
--test-cmd "bash ci/jax_rbe/pr_setup.sh && ci/jax_rbe/pr_test.sh 0.8.2 3.12"
--test-cmd "bash ci/jax_rbe/pr_setup.sh && ci/jax_rbe/pr_test.sh 0.9.1 3.12"

20 changes: 10 additions & 10 deletions .github/workflows/llama-perf.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
name: Llama Performance Benchmarks

# This workflow runs Llama performance benchmarks
# using two different JAX versions: 0.6.0 & 0.8.2
# using two different JAX versions: 0.6.0 & 0.9.1
# For JAX 0.6.0: uses official released wheels
# and Docker image.
# For JAX 0.8.2: uses wheels and Docker image
# For JAX 0.9.1: uses wheels and Docker image
# built from the latest nightly results.

# PS: Ubuntu 24 & ROCm 7.2.0.
Expand Down Expand Up @@ -42,17 +42,17 @@ jobs:
strategy:
fail-fast: false
matrix:
jax-version: ["0.6.0", "0.8.2"]
jax-version: ["0.6.0", "0.9.1"]
include:
- jax-version: "0.6.0"
jaxlib-version: "0.6.0"
# There is no docker image with ROCm 7.2.0 and Jax 0.6.0
# use a ROCm 7.2.0 / Jax 0.8.2 docker image and update Jax
# use a ROCm 7.2.0 / Jax 0.9.1 docker image and update Jax
# before building TE. For the application workload Bazel installs
# its own Jax
docker-image: "ghcr.io/rocm/jax-ubu24.rocm720:nightly"
- jax-version: "0.8.2"
jaxlib-version: "0.8.2"
- jax-version: "0.9.1"
jaxlib-version: "0.9.1"
docker-image: "ghcr.io/rocm/jax-ubu24.rocm720:nightly"
env:
NVTE_FRAMEWORK: jax
Expand Down Expand Up @@ -148,16 +148,16 @@ jobs:
fail-fast: false
max-parallel: 1
matrix:
jax-version: ["0.6.0", "0.8.2"]
jax-version: ["0.6.0", "0.9.1"]
model-name: ["train_dense"]
include:
- jax-version: "0.6.0"
model-name: "train_dense"
jaxlib-version: "0.6.0"
docker-image: "ghcr.io/rocm/jax-ubu24.rocm720:nightly"
- jax-version: "0.8.2"
- jax-version: "0.9.1"
model-name: "train_dense"
jaxlib-version: "0.8.2"
jaxlib-version: "0.9.1"
docker-image: "ghcr.io/rocm/jax-ubu24.rocm720:nightly"
steps:
- name: Checkout source repo
Expand Down Expand Up @@ -294,7 +294,7 @@ jobs:
model-name: "train_dense"
rocm-version: "7.2.0"
python-version: "3.12"
- jax-version: "0.8.2"
- jax-version: "0.9.1"
model-name: "train_dense"
rocm-version: "7.2.0"
python-version: "3.12"
Expand Down
43 changes: 39 additions & 4 deletions .github/workflows/pytest-results-to-db.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
name: Pytest Results to DB
name: Jax CI Test Results to DB

on:
schedule:
- cron: "10 * * * *"
workflow_dispatch:
inputs:
run-id:
required: true
source-repo:
description: "ORG/REPO"
required: false
default: "ROCm/jax"
type: string
source-github-run-id:
description: "Repo GitHub Run ID"
required: false
default: ""
type: string
secrets:
ROCM_JAX_DB_HOSTNAME:
Expand All @@ -16,9 +25,35 @@ on:
ROCM_JAX_DB_NAME:
required: true

concurrency:
group: jax-ci-test-results-to-db
cancel-in-progress: false

jobs:
upload-to-db:
process-test-results:
runs-on: mysqldb
steps:
- name: Checkout source
uses: actions/checkout@v4
- name: Set up Python environment
run: |
python3 -m venv venv
source venv/bin/activate
pip install --upgrade pip
pip install mysql-connector-python
- name: Upload logs to MySQL database
env:
ROCM_JAX_DB_HOSTNAME: ${{ secrets.ROCM_JAX_DB_HOSTNAME }}
ROCM_JAX_DB_USERNAME: ${{ secrets.ROCM_JAX_DB_USERNAME }}
ROCM_JAX_DB_PASSWORD: ${{ secrets.ROCM_JAX_DB_PASSWORD }}
ROCM_JAX_DB_NAME: ${{ secrets.ROCM_JAX_DB_NAME }}
BUCKET: ${{ secrets.AMD_S3_BUCKET_NAME }}
AWS_ACCESS_KEY_ID: ${{ secrets.AMD_S3_BUCKET_ACCESS_KEY }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AMD_S3_BUCKET_SECRET_KEY }}
FILTER_REPO: ${{ inputs.source-repo }}
FILTER_RUN_ID: ${{ inputs.source-github-run-id }}
INPUT_GPU_ARCH: "MI350"
run: |
source venv/bin/activate
./ci/ingest_jax_ci_logs.sh

8 changes: 5 additions & 3 deletions .github/workflows/rocm-perf.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
uses: actions/checkout@v4
with:
repository: ROCm/jax
ref: rocm-jaxlib-v0.8.2
ref: rocm-jaxlib-v0.9.1
path: jax

- name: Build plugin wheels
Expand Down Expand Up @@ -128,7 +128,9 @@ jobs:

- name: Analyze logs to compute median step time
run: |
pip install numpy --break-system-packages
python3 -m venv venv
source venv/bin/activate
pip install numpy
python3 build/analyze_maxtext_logs.py
cat summary.json

Expand Down Expand Up @@ -190,4 +192,4 @@ jobs:
--python-version "$PYTHON_VERSION" \
--rocm-version "$ROCM_VERSION" \
--gfx-version gfx942 \
--jax-version 0.8.2
--jax-version 0.9.1
2 changes: 1 addition & 1 deletion .github/workflows/test-and-upload.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ jobs:
# TODO: Change the repo and ref once we figure out how exactly we're going to
# manage tests
repository: rocm/jax
ref: rocm-jaxlib-v0.8.2
ref: rocm-jaxlib-v0.9.1
path: jax
- name: Apply patches to rocm/jax test repo
run: |
Expand Down
2 changes: 1 addition & 1 deletion BUILDING.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ python3 build/ci_build test $TEST_IMAGE --test-cmd "pytest jax_rocm_plugin/tests
We keep unit tests in the `rocm/jax` repository, and you'll need to clone it
to run the regular JAX unit tests with ROCm,
```shell
git clone --depth 1 --branch rocm-jaxlib-v0.8.2 git@github.com:ROCm/jax.git
git clone --depth 1 --branch rocm-jaxlib-v0.9.1 git@github.com:ROCm/jax.git
# Each release of the ROCm plugin has a corresponding branch. You can find
# more at https://github.com/ROCm/rocm-jax/branches/all?query=rocm-jaxlib

Expand Down
9 changes: 6 additions & 3 deletions build/ci_build
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,9 @@ def build_dockers(
commit_info = _get_commit_info_from_wheel()
dockerfiles = _apply_filters(docker_filters, "Dockerfile.jax")

rocm_ver_tag = "rocm%s" % "".join(rocm_version.split("."))
rocm_version_tag = "".join(rocm_version.split("."))
# Docker tags cannot contain '+', so replace it with '.' for consistency with wheel versions
rocm_ver_tag = "rocm%s" % "".join(rocm_version.split(".")).replace("+", ".")
rocm_version_tag = "".join(rocm_version.split(".")).replace("+", ".")
plugin_namespace = rocm_version[0]
if plugin_namespace == "6":
plugin_namespace = "60"
Expand Down Expand Up @@ -456,7 +457,8 @@ def build_base_dockers(
push_latest=False,
):
dockerfiles = _apply_filters(docker_filters, "Dockerfile.base")
rocm_ver_tag = "rocm%s" % "".join(rocm_version.split("."))
# Docker tags cannot contain '+', so replace it with '.' for consistency with wheel versions
rocm_ver_tag = "rocm%s" % "".join(rocm_version.split(".")).replace("+", ".")

extra_args = []
if add_llvm:
Expand Down Expand Up @@ -670,6 +672,7 @@ def parse_args():
"--llvm-version",
help="Set the version number of LLVM to be installed.",
type=int,
default="18",
)

testp = subp.add_parser("test")
Expand Down
4 changes: 2 additions & 2 deletions build/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
jaxlib==0.8.2
jax==0.8.2
jaxlib==0.9.1
jax==0.9.1
2 changes: 1 addition & 1 deletion ci/Dockerfile.maxtext
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ WORKDIR /maxtext
# Explicitly install jax,jaxlib to avoid pip pulling a newer version (e.g. 0.8.1)
RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
pip install -r requirements.txt && \
pip install jax==0.8.2 jaxlib==0.8.2 && pip freeze
pip install jax==0.9.1 jaxlib==0.9.1 && pip freeze
72 changes: 72 additions & 0 deletions ci/ingest_jax_ci_logs.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#!/bin/bash
set -euo pipefail

: "${INPUT_GPU_ARCH:?}"
: "${FILTER_REPO:=jax-ml/jax}"
: "${FILTER_RUN_ID:=}"

ROOT="jax-ci-test-logs/${FILTER_REPO}/"
CUTOFF="$(date -u -d '2 days ago' +%F)"

echo "Scanning S3 prefix: s3://${BUCKET}/${ROOT}"
echo "Cutoff date (UTC): ${CUTOFF}"

aws s3 ls "s3://${BUCKET}/${ROOT}" --recursive \
| awk '/\/_SUCCESS$/ {print $4}' \
| while read -r SUCCESS_KEY; do

PREFIX="${SUCCESS_KEY%/_SUCCESS}"

# Filter by run id (optional manual override)
if [[ -n "${FILTER_RUN_ID:-}" ]] && [[ "${PREFIX}" != *"_${FILTER_RUN_ID}_"* ]]; then
continue
fi

# Skip already ingested
if aws s3 ls "s3://${BUCKET}/${PREFIX}/_INGESTED" >/dev/null 2>&1; then
continue
fi

# Extract run_dir
RUN_DIR="$(basename "$(dirname "${PREFIX}")")"
RUN_DATE="${RUN_DIR:0:10}"

# Skip older than cutoff
if [[ "${RUN_DATE}" < "${CUTOFF}" ]]; then
continue
fi

# Skip if logs.tar.gz missing
if ! aws s3 ls "s3://${BUCKET}/${PREFIX}/logs.tar.gz" >/dev/null 2>&1; then
echo "Skipping ${PREFIX}: logs.tar.gz not found"
continue
fi

echo "Ingesting: ${PREFIX}"

WD="$(mktemp -d)"

aws s3 cp "s3://${BUCKET}/${PREFIX}/run-manifest.json" "${WD}/run-manifest.json"
aws s3 cp "s3://${BUCKET}/${PREFIX}/logs.tar.gz" "${WD}/logs.tar.gz"

mkdir -p "${WD}/logs_dir/extracted"
cp "${WD}/run-manifest.json" "${WD}/logs_dir/"
tar -xzf "${WD}/logs.tar.gz" -C "${WD}/logs_dir/extracted"

if python3 ci/upload_pytest_to_db.py \
--local_logs_dir "${WD}/logs_dir" \
--run-tag "ci-run" \
--gpu-tag "${INPUT_GPU_ARCH}" \
--artifact_uri "s3://${BUCKET}/${PREFIX}"
then
printf '' | aws s3 cp - "s3://${BUCKET}/${PREFIX}/_INGESTED"
echo "Marked _INGESTED: ${PREFIX}"
else
echo "Skipping ${PREFIX}: ingest failed"
fi

rm -rf "${WD}"
done

echo "Done"

2 changes: 1 addition & 1 deletion ci/jax_rbe/pr_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ python3 build/build.py build --wheels=jax-rocm-plugin --configure_only --python_
--config=rocm_rbe \
--noremote_accept_cached \
--//jax:build_jaxlib=false \
--action_env=TF_ROCM_AMDGPU_TARGETS="gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201" \
--action_env=TF_ROCM_AMDGPU_TARGETS="gfx9-generic,gfx9-4-generic,gfx1030,gfx11-generic,gfx12-generic" \
--test_verbose_timeout_warnings \
--test_output=errors \
//tests:core_test_gpu \
Expand Down
Loading
Loading