diff --git a/BUILDING.md b/BUILDING.md index f993def2f..e91edd118 100644 --- a/BUILDING.md +++ b/BUILDING.md @@ -65,6 +65,13 @@ pip install \ You may need to pass `--force-reinstall` to your `pip install` command if you already have an installation of the plugin packages. +## Prefetching dependencies (optional) + +The wheel build uses `bazel build` under the hood. When you run the build CLI +directly (e.g. `python jax_rocm_plugin/build/build.py build ...` or via the +Makefile/stack), you can pass `--bazel_options=--nobuild` to only fetch +dependencies (no compile). + ## Troubleshooting If you have an older version of Docker on your system, you might get an error @@ -203,7 +210,7 @@ The plugin repo pulls together code from several other repositories as part of i ### `jax-ml/jax` Pulled in [via Bazel](https://github.com/ROCm/rocm-jax/blob/master/jax_rocm_plugin/third_party/jax/workspace.bzl#L12) -and is only used to [build the rocm_jaxX_plugin wheel](https://github.com/ROCm/rocm-jax/blob/master/jax_rocm_plugin/jaxlib_ext/tools/BUILD.bazel#L26). +and is only used to [build the jax-rocmX-plugin wheel](https://github.com/ROCm/rocm-jax/blob/master/jax_rocm_plugin/jaxlib_ext/tools/BUILD.bazel#L26). Bazel applies a [handful of patches](https://github.com/ROCm/rocm-jax/blob/master/jax_rocm_plugin/third_party/jax/workspace.bzl#L14) to the kernel code when it pulls jax-ml/jax. That kernel code is mostly stuff that we share with Nvidia, changes to it from AMD are few and far in-between, diff --git a/jax_rocm_plugin/build/build.py b/jax_rocm_plugin/build/build.py index b8cb454b1..7a9e4408d 100755 --- a/jax_rocm_plugin/build/build.py +++ b/jax_rocm_plugin/build/build.py @@ -56,10 +56,25 @@ `python build/build.py requirements_update` """ -# Define the build target for each wheel. +# Define the build target for each wheel. Use jax_wheel targets so that +# "bazel build" is used (enables --nobuild for prefetch-only workflows). WHEEL_BUILD_TARGET_DICT = { - "jax-rocm-plugin": "//jaxlib_ext/tools:build_gpu_kernels_wheel", - "jax-rocm-pjrt": "//pjrt/tools:build_gpu_plugin_wheel", + "jax-rocm-plugin": "//jaxlib_ext/tools:jax_rocm7_plugin_wheel", + "jax-rocm-pjrt": "//pjrt/tools:jax_rocm7_pjrt_wheel", +} +WHEEL_BUILD_TARGET_EDITABLE_DICT = { + "jax-rocm-plugin": "//jaxlib_ext/tools:jax_rocm7_plugin_wheel_editable", + "jax-rocm-pjrt": "//pjrt/tools:jax_rocm7_pjrt_wheel_editable", +} +# Bazel output dir (under bazel-bin) for each wheel for copy step. +WHEEL_BAZEL_DIST_DIR = { + "jax-rocm-plugin": os.path.join("bazel-bin", "jaxlib_ext", "tools", "dist"), + "jax-rocm-pjrt": os.path.join("bazel-bin", "pjrt", "tools", "dist"), +} +# Wheel file prefix for glob when copying (e.g. jax_rocm7_plugin*.whl). +WHEEL_GLOB_PREFIX = { + "jax-rocm-plugin": "jax_rocm7_plugin", + "jax-rocm-pjrt": "jax_rocm7_pjrt", } @@ -425,7 +440,10 @@ async def main(): for option in args.bazel_startup_options: bazel_command_base.append(option) - bazel_command_base.append("run") + if args.command == "requirements_update": + bazel_command_base.append("run") + else: + bazel_command_base.append("build") if args.python_version: # Do not add --repo_env=HERMETIC_PYTHON_VERSION with default args.python_version @@ -478,6 +496,10 @@ async def main(): wheel_build_command_base = copy.deepcopy(bazel_command_base) + # So the wheel build action (BuildJaxWheel) can find tools like patchelf + # that are on PATH in the runner (e.g. manylinux container). + wheel_build_command_base.append("--action_env=PATH") + wheel_cpus = { "darwin_arm64": "arm64", "darwin_x86_64": "x86_64", @@ -642,6 +664,9 @@ async def main(): f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}" ) + # Release version with no suffix (BUILDING.md: jax_rocm7_pjrt-X.X.X-...). + wheel_build_command_base.append("--repo_env=ML_WHEEL_TYPE=release") + # Append additional build options at the end to override any options set in # .bazelrc or above. if args.bazel_options: @@ -683,6 +708,20 @@ async def main(): sys.exit(1) wheel_build_command = copy.deepcopy(wheel_build_command_base) + # Pass build settings used by the jax_wheel rule (from @jax repo). + wheel_build_command.append("--@jax//jaxlib/tools:output_path=dist") + wheel_build_command.append( + f"--@jax//jaxlib/tools:jaxlib_git_hash={git_hash}" + ) + + if args.editable: + build_target = WHEEL_BUILD_TARGET_EDITABLE_DICT.get( + wheel, WHEEL_BUILD_TARGET_DICT[wheel] + ) + else: + build_target = WHEEL_BUILD_TARGET_DICT[wheel] + wheel_build_command.append(build_target) + print("\n") logger.info( "Building %s for %s %s...", @@ -691,43 +730,6 @@ async def main(): arch, ) - # Append the build target to the Bazel command. - build_target = WHEEL_BUILD_TARGET_DICT[wheel] - wheel_build_command.append(build_target) - - wheel_build_command.append("--") - - if args.editable: - logger.info("Building an editable build") - output_path = os.path.join(output_path, wheel) - wheel_build_command.append("--editable") - - wheel_build_command.append(f'--output_path="{output_path}"') - wheel_build_command.append(f"--cpu={target_cpu}") - - if "cuda" in wheel: - wheel_build_command.append("--enable-cuda=True") - if args.cuda_version: - cuda_major_version = args.cuda_version.split(".")[0] - else: - cuda_major_version = args.cuda_major_version - wheel_build_command.append(f"--platform_version={cuda_major_version}") - - if "rocm" in wheel: - wheel_build_command.append("--enable-rocm=True") - wheel_build_command.append(f"--platform_version={args.rocm_version}") - - wheel_build_command.append(f"--rocm_jax_git_hash={git_hash}") - - use_local_xla = extract_override_path(args.bazel_options, "xla") - use_local_jax = extract_override_path(args.bazel_options, "jax") - - if use_local_xla: - wheel_build_command.append(f"--use_local_xla={use_local_xla}") - - if use_local_jax: - wheel_build_command.append(f"--use_local_jax={use_local_jax}") - result = await executor.run( wheel_build_command.get_command_as_string(), args.dry_run, @@ -739,6 +741,28 @@ async def main(): f"Command failed with return code {result.return_code}" ) + # Copy built wheels from bazel-bin to user output_path. + output_path = args.output_path + for wheel in args.wheels.split(","): + if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel: + wheel = "jax-" + wheel + if wheel not in WHEEL_BAZEL_DIST_DIR: + continue + bazel_dir = WHEEL_BAZEL_DIST_DIR[wheel] + if not os.path.isdir(bazel_dir): + logging.warning( + "Bazel output dir not found: %s (skipping copy)", bazel_dir + ) + continue + if args.editable: + src_dir = os.path.join(bazel_dir, WHEEL_GLOB_PREFIX[wheel]) + dst_dir = os.path.join(output_path, wheel) + if os.path.isdir(src_dir): + utils.copy_dir_recursively(src_dir, dst_dir) + else: + glob_pattern = f"{WHEEL_GLOB_PREFIX[wheel]}*.whl" + utils.copy_individual_files(bazel_dir, output_path, glob_pattern) + # Exit with success if all wheels in the list were built successfully. sys.exit(0) diff --git a/jax_rocm_plugin/build/rocm/tools/build_wheels.py b/jax_rocm_plugin/build/rocm/tools/build_wheels.py index 7211870b8..1a00e8b58 100644 --- a/jax_rocm_plugin/build/rocm/tools/build_wheels.py +++ b/jax_rocm_plugin/build/rocm/tools/build_wheels.py @@ -180,6 +180,8 @@ def build_plugin_wheel( "--output_path=%s" % output_dir, # Use roctracer (v1) instead of rocprofiler-sdk (v3) for profiling. "--bazel_options=--define=xla_rocm_profiler=v1", + # Release version with no suffix (BUILDING.md: jax_rocm7_pjrt-X.X.X-...). + "--bazel_options=--repo_env=ML_WHEEL_TYPE=release", ] # Add clang path if clang is used. @@ -205,9 +207,12 @@ def build_plugin_wheel( env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) LOG.info("Running %r from cwd=%r", cmd, plugin_path) - pattern = re.compile("Output wheel: (.+)\n") - - _run_scan_for_output(cmd, pattern, env=env, cwd=plugin_path, capture="stderr") + # With bazel build, wheels are copied to output_dir by build.py + result = subprocess.run(cmd, env=env, cwd=plugin_path, check=False) + if result.returncode != 0: + raise RuntimeError( + "Plugin wheel build failed with return code: %d" % result.returncode + ) # pylint: disable=R0913,R0917,too-many-locals @@ -246,7 +251,6 @@ def build_jaxlib_wheel( "--verbose", "--bazel_options=--action_env=HIPCC_COMPILE_FLAGS_APPEND=--offload-compress", "--bazel_options=--repo_env=ML_WHEEL_TYPE=release", - f"--bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX=+rocm{version_string}", ] # Add clang path if clang is used. @@ -338,8 +342,13 @@ def to_cpy_ver(python_version): return "cp%d%d" % (int(tup[0]), int(tup[1])) -def fix_wheel(path, jax_path): - """Fix auditwheel compliance using fixwheel.py and auditwheel.""" +def fix_wheel(path, jax_path, wheelhouse_dir=None): + """Fix auditwheel compliance using fixwheel.py and auditwheel. + + auditwheel repair writes the repaired wheel to ./wheelhouse (relative to cwd). + Pass wheelhouse_dir so we run with cwd set to its parent; then do not copy + the original from dist_manylinux to avoid duplicates. + """ try: # NOTE(mrodden): fixwheel needs auditwheel 6.0.0, which has a min python of 3.8 # so use one of the CPythons in /opt to run @@ -355,7 +364,9 @@ def fix_wheel(path, jax_path): fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py") cmd = ["python", fixwheel_path, path] - subprocess.run(cmd, check=True, env=env) + # So auditwheel repair writes to /wheelhouse; use plugin dir as cwd. + cwd = os.path.dirname(wheelhouse_dir) if wheelhouse_dir else jax_path + subprocess.run(cmd, check=True, env=env, cwd=cwd) LOG.info("Wheel fix completed successfully.") except subprocess.CalledProcessError as cpe: LOG.error("Subprocess failed with error: %s", cpe) @@ -437,7 +448,9 @@ def main(): python_versions = args.python_versions.split(",") manylinux_output_dir = "dist_manylinux" - wheelhouse_dir = "/wheelhouse/" + # Use plugin-relative wheelhouse so it works in Docker and when run on the host. + wheelhouse_dir = os.path.join(args.plugin_path, "wheelhouse") + os.makedirs(wheelhouse_dir, exist_ok=True) rocm_path = args.rocm_path if args.rocm_version: @@ -482,11 +495,11 @@ def main(): args.compiler, wheels="jax-rocm-pjrt", ) - # Fix PJRT wheel. + # Fix PJRT wheel (auditwheel repair writes repaired wheel to wheelhouse_dir). wheel_paths = find_wheels(full_output_path) for wheel_path in wheel_paths: if "pjrt" in os.path.basename(wheel_path).lower(): - fix_wheel(wheel_path, args.plugin_path) + fix_wheel(wheel_path, args.plugin_path, wheelhouse_dir) # Build plugin wheel for each Python version. for py in python_versions: @@ -502,19 +515,17 @@ def main(): args.compiler, wheels="jax-rocm-plugin", ) - # Fix plugin wheels for this Python version. + # Fix plugin wheels for this Python version (repaired wheels go to wheelhouse_dir). wheel_paths = find_wheels(full_output_path) for wheel_path in wheel_paths: base = os.path.basename(wheel_path) # Only fix plugin wheels, skip already-fixed PJRT wheel. if "plugin" in base.lower(): - fix_wheel(wheel_path, args.plugin_path) + fix_wheel(wheel_path, args.plugin_path, wheelhouse_dir) - # Copy plugin + PJRT wheels to wheelhouse. - wheel_paths = find_wheels(full_output_path) - for whl in wheel_paths: - LOG.info("Copying %s into %s", whl, wheelhouse_dir) - shutil.copy(whl, wheelhouse_dir) + # Repaired plugin/PJRT wheels are already in wheelhouse_dir from fix_wheel + # (auditwheel repair writes there). Do not copy originals from dist_manylinux + # to avoid duplicate wheels # Optionally build jaxlib wheel if --jax-path is provided. if args.jax_path: diff --git a/jax_rocm_plugin/build/tools/utils.py b/jax_rocm_plugin/build/tools/utils.py index 92b265caa..8bfd612f0 100644 --- a/jax_rocm_plugin/build/tools/utils.py +++ b/jax_rocm_plugin/build/tools/utils.py @@ -219,9 +219,14 @@ def get_gcc_major_version(gcc_path: str): def get_jax_configure_bazel_options(bazel_command: list[str]): """Returns the bazel options to be written to .jax_configure.bazelrc.""" - # Get the index of the "run" parameter. Build options will come after "run" so - # we find the index of "run" and filter everything after it. - start = bazel_command.index("run") + # Get the index of the command verb ("run" or "build"). Options come after it. + if "run" in bazel_command: + start = bazel_command.index("run") + elif "build" in bazel_command: + start = bazel_command.index("build") + else: + logging.error("Bazel command has neither 'run' nor 'build'") + return "" jax_configure_bazel_options = "" try: for i in range(start + 1, len(bazel_command)): @@ -260,6 +265,41 @@ def get_githash(): return "" +def copy_dir_recursively(src, dst): + """Copy a directory tree from src to dst.""" + if os.path.exists(dst): + shutil.rmtree(dst) + os.makedirs(dst, exist_ok=True) + for root, dirs, files in os.walk(src): + relative_path = os.path.relpath(root, src) + dst_dir = os.path.join(dst, relative_path) + os.makedirs(dst_dir, exist_ok=True) + for f in files: + src_file = os.path.join(root, f) + dst_file = os.path.join(dst_dir, f) + shutil.copy2(src_file, dst_file) + logging.info("Editable wheel path: %s", dst) + + +def copy_individual_files(src: str, dst: str, glob_pattern: str): + """Copy files matching glob_pattern from src to dst.""" + import glob as glob_module + + os.makedirs(dst, exist_ok=True) + logging.debug( + "Copying files matching pattern %r from %r to %r", + glob_pattern, + src, + dst, + ) + for f in glob_module.glob(os.path.join(src, glob_pattern)): + dst_file = os.path.join(dst, os.path.basename(f)) + if os.path.exists(dst_file): + os.remove(dst_file) + shutil.copy2(f, dst_file) + logging.info("Distribution path: %s", dst_file) + + def _parse_string_as_bool(s): """Parses a string as a boolean value.""" lower = s.lower() diff --git a/jax_rocm_plugin/jaxlib_ext/tools/BUILD.bazel b/jax_rocm_plugin/jaxlib_ext/tools/BUILD.bazel index 6b86fbd6c..4670ed289 100644 --- a/jax_rocm_plugin/jaxlib_ext/tools/BUILD.bazel +++ b/jax_rocm_plugin/jaxlib_ext/tools/BUILD.bazel @@ -20,12 +20,16 @@ licenses(["notice"]) # Apache 2 package(default_visibility = ["//visibility:public"]) -# Source files for the ROCm plugin wheel - used by both the tool and jax_wheel -ROCM_PLUGIN_SOURCES = [ +# Static files only (no rule outputs). Used for py_binary data and wheel_sources static_srcs. +ROCM_PLUGIN_STATIC_SRCS = [ "LICENSE.txt", "//jax_plugins/rocm:plugin_pyproject.toml", "//jax_plugins/rocm:plugin_setup.py", "//pjrt/python:version.py", +] + +# Full list for py_binary data (static + @jax targets so runfiles contain .so when using build_gpu_kernels_wheel). +ROCM_PLUGIN_SOURCES = ROCM_PLUGIN_STATIC_SRCS + [ "@jax//jaxlib/rocm:rocm_gpu_support", "@jax//jaxlib/rocm:rocm_plugin_extension", ] @@ -49,23 +53,65 @@ py_binary( ], ) +# Same script, no data. Used by jax_wheel so inputs come only from source_files +# (target config), avoiding a second full build in exec config. +py_binary( + name = "build_gpu_kernels_wheel_via_sources", + main = "build_gpu_kernels_wheel.py", + srcs = ["build_gpu_kernels_wheel.py"], + args = [ + "--xla-commit", + xla_commit, + "--jax-commit", + jax_commit, + ], + deps = [ + "build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi_build//:pkg", + "@pypi_setuptools//:pkg", + "@pypi_wheel//:pkg", + ], +) + py_library( name = "build_utils", srcs = ["build_utils.py"], ) -# Wheel sources using the same list +# Wheel sources: static_srcs for plain files; data_srcs so collect_data_aspect gathers +# .so from @jax py_library/nanobind_extension (data/deps). Putting @jax targets in +# static_srcs does not expand—their DefaultInfo.files are empty. wheel_sources( name = "rocm_plugin_sources", - static_srcs = ROCM_PLUGIN_SOURCES, + static_srcs = ROCM_PLUGIN_STATIC_SRCS, + data_srcs = [ + "@jax//jaxlib/rocm:rocm_gpu_support", + "@jax//jaxlib/rocm:rocm_plugin_extension", + ], + py_srcs = [ + "@jax//jaxlib/rocm:rocm_gpu_support", + ], ) -# ROCm Plugin Wheel (jax_rocm7_plugin) - contains GPU kernels +# ROCm Plugin Wheel (jax_rocm7_plugin) - contains GPU kernels. +# Use via_sources binary; .so files come from source_files via data_srcs (collect_data_aspect). jax_wheel( name = "jax_rocm7_plugin_wheel", enable_rocm = True, platform_version = "7", source_files = [":rocm_plugin_sources"], - wheel_binary = ":build_gpu_kernels_wheel", + wheel_binary = ":build_gpu_kernels_wheel_via_sources", + wheel_name = "jax_rocm7_plugin", +) + +# Editable plugin wheel (same sources, editable = True) +jax_wheel( + name = "jax_rocm7_plugin_wheel_editable", + editable = True, + enable_rocm = True, + platform_version = "7", + source_files = [":rocm_plugin_sources"], + wheel_binary = ":build_gpu_kernels_wheel_via_sources", wheel_name = "jax_rocm7_plugin", ) diff --git a/jax_rocm_plugin/jaxlib_ext/tools/build_gpu_kernels_wheel.py b/jax_rocm_plugin/jaxlib_ext/tools/build_gpu_kernels_wheel.py index 383186804..3c5a526d3 100644 --- a/jax_rocm_plugin/jaxlib_ext/tools/build_gpu_kernels_wheel.py +++ b/jax_rocm_plugin/jaxlib_ext/tools/build_gpu_kernels_wheel.py @@ -30,6 +30,21 @@ from bazel_tools.tools.python.runfiles import runfiles from jaxlib_ext.tools import build_utils + +def _find_patchelf(): + """Return path to patchelf, or None if not found. Prefer env, then PATH, then common paths.""" + exe = os.environ.get("PATCHELF") + if exe and os.path.isfile(exe) and os.access(exe, os.X_OK): + return exe + exe = shutil.which("patchelf") + if exe: + return exe + for path in ("/usr/bin/patchelf", "/usr/local/bin/patchelf"): + if os.path.isfile(path) and os.access(path, os.X_OK): + return path + return None + + parser = argparse.ArgumentParser(fromfile_prefix_chars="@") parser.add_argument( "--output_path", @@ -50,8 +65,8 @@ parser.add_argument( "--srcs", action="append", - help="Source files passed by jax_wheel macro. If provided, these are used " - "for config files. .so files always come from runfiles.", + help="Source files passed by jax_wheel macro. If provided, config and .so " + "files are taken from here; otherwise from runfiles.", ) parser.add_argument( "--cpu", default=None, required=True, help="Target CPU architecture. Required." @@ -196,8 +211,9 @@ def prepare_wheel_rocm(wheel_sources_path: pathlib.Path, *, cpu, rocm_version, s plugin_dir, xla_commit_hash, jax_commit_hash, get_rocm_jax_git_hash() ) - # Copy .so files: always from jax runfiles - for so_file in [ + # Copy .so files: from --srcs when present there, else from runfiles. + # jax_wheel may pass only plugin-repo files in --srcs; .so from @jax are in runfiles. + so_files = [ f"_linalg.{pyext}", f"_prng.{pyext}", f"_solver.{pyext}", @@ -206,8 +222,28 @@ def prepare_wheel_rocm(wheel_sources_path: pathlib.Path, *, cpu, rocm_version, s f"_rnn.{pyext}", f"_triton.{pyext}", f"rocm_plugin_extension.{pyext}", - ]: - shutil.copy(r.Rlocation(f"jax/jaxlib/rocm/{so_file}"), plugin_dir) + ] + for so_file in so_files: + try: + if srcs: + shutil.copy(find_src(srcs, so_file), plugin_dir) + else: + runfiles_path = r.Rlocation(f"jax/jaxlib/rocm/{so_file}") + if runfiles_path is None: + raise FileNotFoundError( + f"'{so_file}' not in runfiles (run with data=ROCM_PLUGIN_SOURCES)" + ) + shutil.copy(runfiles_path, plugin_dir) + except FileNotFoundError: + runfiles_path = r.Rlocation(f"jax/jaxlib/rocm/{so_file}") + if runfiles_path is None: + raise FileNotFoundError( + f"'{so_file}' not in --srcs and not in runfiles. " + "Ensure jax_wheel source_files includes " + "@jax//jaxlib/rocm:rocm_gpu_support and " + "@jax//jaxlib/rocm:rocm_plugin_extension." + ) from None + shutil.copy(runfiles_path, plugin_dir) # NOTE(mrodden): this is a hack to change/set rpath values # in the shared objects that are produced by the bazel build @@ -217,14 +253,12 @@ def prepare_wheel_rocm(wheel_sources_path: pathlib.Path, *, cpu, rocm_version, s # which won't be correct until we make changes to # the xla/tsl/jax plugin build - try: - subprocess.check_output(["which", "patchelf"]) - except subprocess.CalledProcessError as ex: - mesg = ( + patchelf_cmd = _find_patchelf() + if patchelf_cmd is None: + raise RuntimeError( "rocm plugin and kernel wheel builds require patchelf. " "please install 'patchelf' and run again" ) - raise RuntimeError(mesg) from ex files = [ f"_linalg.{pyext}", @@ -245,7 +279,7 @@ def prepare_wheel_rocm(wheel_sources_path: pathlib.Path, *, cpu, rocm_version, s if not perms & stat.S_IWUSR: fix_perms = True os.chmod(so_path, perms | stat.S_IWUSR) - subprocess.check_call(["patchelf", "--set-rpath", runpath, so_path]) + subprocess.check_call([patchelf_cmd, "--set-rpath", runpath, so_path]) if fix_perms: os.chmod(so_path, perms) diff --git a/jax_rocm_plugin/pjrt/tools/BUILD.bazel b/jax_rocm_plugin/pjrt/tools/BUILD.bazel index e62332e1b..f4edaeca2 100644 --- a/jax_rocm_plugin/pjrt/tools/BUILD.bazel +++ b/jax_rocm_plugin/pjrt/tools/BUILD.bazel @@ -50,6 +50,27 @@ py_binary( ], ) +# Same script, no data. Used by jax_wheel so inputs come only from source_files +# (target config), avoiding a second full build in exec config. +py_binary( + name = "build_gpu_plugin_wheel_via_sources", + main = "build_gpu_plugin_wheel.py", + srcs = ["build_gpu_plugin_wheel.py"], + args = [ + "--xla-commit", + xla_commit, + "--jax-commit", + jax_commit, + ], + deps = [ + "build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi_build//:pkg", + "@pypi_setuptools//:pkg", + "@pypi_wheel//:pkg", + ], +) + py_library( name = "build_utils", srcs = ["build_utils.py"], @@ -68,6 +89,18 @@ jax_wheel( no_abi = True, platform_version = "7", source_files = [":rocm_pjrt_sources"], - wheel_binary = ":build_gpu_plugin_wheel", + wheel_binary = ":build_gpu_plugin_wheel_via_sources", + wheel_name = "jax_rocm7_pjrt", +) + +# Editable PJRT wheel (same sources, editable = True) +jax_wheel( + name = "jax_rocm7_pjrt_wheel_editable", + editable = True, + enable_rocm = True, + no_abi = True, + platform_version = "7", + source_files = [":rocm_pjrt_sources"], + wheel_binary = ":build_gpu_plugin_wheel_via_sources", wheel_name = "jax_rocm7_pjrt", ) diff --git a/jax_rocm_plugin/pjrt/tools/build_gpu_plugin_wheel.py b/jax_rocm_plugin/pjrt/tools/build_gpu_plugin_wheel.py index 1d9c2c335..92366cac3 100644 --- a/jax_rocm_plugin/pjrt/tools/build_gpu_plugin_wheel.py +++ b/jax_rocm_plugin/pjrt/tools/build_gpu_plugin_wheel.py @@ -30,6 +30,20 @@ from bazel_tools.tools.python.runfiles import runfiles from pjrt.tools import build_utils + +def _find_patchelf(): + """Return path to patchelf, or None if not found. Prefer env, then PATH, then common paths.""" + exe = os.environ.get("PATCHELF") + if exe and os.path.isfile(exe) and os.access(exe, os.X_OK): + return exe + exe = shutil.which("patchelf") + if exe: + return exe + for path in ("/usr/bin/patchelf", "/usr/local/bin/patchelf"): + if os.path.isfile(path) and os.access(path, os.X_OK): + return path + return None + parser = argparse.ArgumentParser(fromfile_prefix_chars="@") parser.add_argument( "--sources_path", @@ -201,14 +215,12 @@ def prepare_rocm_plugin_wheel( # which won't be correct until we make changes to # the xla/tsl/jax plugin build - try: - subprocess.check_output(["which", "patchelf"]) - except subprocess.CalledProcessError as ex: - mesg = ( + patchelf_cmd = _find_patchelf() + if patchelf_cmd is None: + raise RuntimeError( "rocm plugin and kernel wheel builds require patchelf. " "please install 'patchelf' and run again" ) - raise RuntimeError(mesg) from ex shared_obj_path = os.path.join(plugin_dir, "xla_rocm_plugin.so") runpath = "$ORIGIN/../rocm/lib:$ORIGIN/../../rocm/lib:/opt/rocm/lib" @@ -218,7 +230,7 @@ def prepare_rocm_plugin_wheel( if not perms & stat.S_IWUSR: fix_perms = True os.chmod(shared_obj_path, perms | stat.S_IWUSR) - subprocess.check_call(["patchelf", "--set-rpath", runpath, shared_obj_path]) + subprocess.check_call([patchelf_cmd, "--set-rpath", runpath, shared_obj_path]) if fix_perms: os.chmod(shared_obj_path, perms)