Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion build/ci_build
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ def build_dockers(
else:
dockerfiles.append(path)

rocm_ver_tag = "rocm%s" % "".join(rocm_version.split("."))
# Docker tags cannot contain '+', so replace it with '.' for consistency with wheel versions
rocm_ver_tag = "rocm%s" % "".join(rocm_version.split(".")).replace("+", ".")
plugin_namespace = rocm_version[0]
if plugin_namespace == "6":
plugin_namespace = "60"
Expand Down
7 changes: 4 additions & 3 deletions jax_rocm_plugin/build/rocm/ci_build
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def build_jaxlib_wheel(rocm_version, pyver_string):
f"--rocm_path=/opt/rocm-{rocm_version}",
f"--clang_path=/opt/rocm-{rocm_version}/llvm/bin/clang",
"--bazel_options=--repo_env=ML_WHEEL_TYPE=release",
f"--bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX=\"+rocm{rocm_version}\"" \
f"--bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX=\"+rocm{rocm_version.replace('+', '.')}\"" \
"&&",
"auditwheel", "repair", "--plat", "manylinux_2_27_x86_64", "--only-plat", "dist/jaxlib*.whl", "-w", "../", "&&",
]
Expand Down Expand Up @@ -85,7 +85,8 @@ def dist_wheels(
xla_path = os.path.abspath(xla_path)

# create manylinux image with requested ROCm installed
image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")
# Docker tags cannot contain '+', so replace it with '.' for consistency with wheel versions
image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "").replace("+", ".")

# Try removing the Docker image.
try:
Expand Down Expand Up @@ -172,7 +173,7 @@ def dist_wheels(
"-e",
"GIT_WORK_TREE=/repo",
"-e",
"ROCM_VERSION_EXTRA=" + rocm_version,
"ROCM_VERSION_EXTRA=" + rocm_version.replace("+", "."),
image,
"bash",
"-c",
Expand Down
22 changes: 15 additions & 7 deletions jax_rocm_plugin/build/rocm/tools/get_rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import ssl
import subprocess
import sys
import urllib.parse
import urllib.request


Expand Down Expand Up @@ -207,22 +208,29 @@ def _install_therock(rocm_version, therock_path):
else:
os.makedirs(rocm_real_path)
tar_path = "/tmp/therock.tar.gz"
with urllib.request.urlopen(therock_path) as response:
# URL-encode special characters (e.g., '+' becomes '%2B')
# Include '%' in safe to avoid double-encoding already-encoded URLs
encoded_url = urllib.parse.quote(therock_path, safe=":/?&=%")
with urllib.request.urlopen(encoded_url) as response:
if response.status == 200:
with open(tar_path, "wb") as tar_file:
tar_file.write(response.read())
cmd = ["tar", "-xzf", tar_path, "-C", rocm_real_path]
LOG.info("Running %r", cmd)
subprocess.check_call(cmd)

os.symlink(rocm_real_path, rocm_sym_path, target_is_directory=True)
if not os.path.exists(rocm_sym_path):
os.symlink(rocm_real_path, rocm_sym_path, target_is_directory=True)

# Make a symlink to amdgcn to fix LLVM not being able to find binaries
os.symlink(
rocm_real_path + "/lib/llvm/amdgcn/",
rocm_real_path + "/amdgcn",
target_is_directory=True,
)
# Only create if it doesn't already exist (newer TheRock tarballs include it)
amdgcn_symlink = rocm_real_path + "/amdgcn"
if not os.path.exists(amdgcn_symlink):
os.symlink(
rocm_real_path + "/lib/llvm/amdgcn/",
amdgcn_symlink,
target_is_directory=True,
)


def _setup_internal_repo(system, rocm_version, job_name, build_num):
Expand Down
Loading