diff --git a/build/ci_build b/build/ci_build index 126e25eca6..3aaef3ded5 100755 --- a/build/ci_build +++ b/build/ci_build @@ -141,7 +141,8 @@ def build_dockers( else: dockerfiles.append(path) - rocm_ver_tag = "rocm%s" % "".join(rocm_version.split(".")) + # Docker tags cannot contain '+', so replace it with '.' for consistency with wheel versions + rocm_ver_tag = "rocm%s" % "".join(rocm_version.split(".")).replace("+", ".") plugin_namespace = rocm_version[0] if plugin_namespace == "6": plugin_namespace = "60" diff --git a/jax_rocm_plugin/build/rocm/ci_build b/jax_rocm_plugin/build/rocm/ci_build index 42b7948eb6..e72dc97810 100755 --- a/jax_rocm_plugin/build/rocm/ci_build +++ b/jax_rocm_plugin/build/rocm/ci_build @@ -56,7 +56,7 @@ def build_jaxlib_wheel(rocm_version, pyver_string): f"--rocm_path=/opt/rocm-{rocm_version}", f"--clang_path=/opt/rocm-{rocm_version}/llvm/bin/clang", "--bazel_options=--repo_env=ML_WHEEL_TYPE=release", - f"--bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX=\"+rocm{rocm_version}\"" \ + f"--bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX=\"+rocm{rocm_version.replace('+', '.')}\"" \ "&&", "auditwheel", "repair", "--plat", "manylinux_2_27_x86_64", "--only-plat", "dist/jaxlib*.whl", "-w", "../", "&&", ] @@ -85,7 +85,8 @@ def dist_wheels( xla_path = os.path.abspath(xla_path) # create manylinux image with requested ROCm installed - image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "") + # Docker tags cannot contain '+', so replace it with '.' for consistency with wheel versions + image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "").replace("+", ".") # Try removing the Docker image. try: @@ -172,7 +173,7 @@ def dist_wheels( "-e", "GIT_WORK_TREE=/repo", "-e", - "ROCM_VERSION_EXTRA=" + rocm_version, + "ROCM_VERSION_EXTRA=" + rocm_version.replace("+", "."), image, "bash", "-c", diff --git a/jax_rocm_plugin/build/rocm/tools/get_rocm.py b/jax_rocm_plugin/build/rocm/tools/get_rocm.py index b75f1178d5..60dae5229a 100644 --- a/jax_rocm_plugin/build/rocm/tools/get_rocm.py +++ b/jax_rocm_plugin/build/rocm/tools/get_rocm.py @@ -29,6 +29,7 @@ import ssl import subprocess import sys +import urllib.parse import urllib.request @@ -207,7 +208,10 @@ def _install_therock(rocm_version, therock_path): else: os.makedirs(rocm_real_path) tar_path = "/tmp/therock.tar.gz" - with urllib.request.urlopen(therock_path) as response: + # URL-encode special characters (e.g., '+' becomes '%2B') + # Include '%' in safe to avoid double-encoding already-encoded URLs + encoded_url = urllib.parse.quote(therock_path, safe=":/?&=%") + with urllib.request.urlopen(encoded_url) as response: if response.status == 200: with open(tar_path, "wb") as tar_file: tar_file.write(response.read()) @@ -215,14 +219,18 @@ def _install_therock(rocm_version, therock_path): LOG.info("Running %r", cmd) subprocess.check_call(cmd) - os.symlink(rocm_real_path, rocm_sym_path, target_is_directory=True) + if not os.path.exists(rocm_sym_path): + os.symlink(rocm_real_path, rocm_sym_path, target_is_directory=True) # Make a symlink to amdgcn to fix LLVM not being able to find binaries - os.symlink( - rocm_real_path + "/lib/llvm/amdgcn/", - rocm_real_path + "/amdgcn", - target_is_directory=True, - ) + # Only create if it doesn't already exist (newer TheRock tarballs include it) + amdgcn_symlink = rocm_real_path + "/amdgcn" + if not os.path.exists(amdgcn_symlink): + os.symlink( + rocm_real_path + "/lib/llvm/amdgcn/", + amdgcn_symlink, + target_is_directory=True, + ) def _setup_internal_repo(system, rocm_version, job_name, build_num):