Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion BUILDING.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ pip install \
You may need to pass `--force-reinstall` to your `pip install` command if you
already have an installation of the plugin packages.

## Prefetching dependencies (optional)

The wheel build uses `bazel build` under the hood. Whether you run the build CLI
directly (e.g. `python jax_rocm_plugin/build/build.py build ...`) or go through
the Makefile/stack, you can pass `--bazel_options=--nobuild` to fetch
dependencies only, without compiling anything.

## Troubleshooting

If you have an older version of Docker on your system, you might get an error
Expand Down Expand Up @@ -203,7 +210,7 @@ The plugin repo pulls together code from several other repositories as part of i
### `jax-ml/jax`

Pulled in [via Bazel](https://github.com/ROCm/rocm-jax/blob/master/jax_rocm_plugin/third_party/jax/workspace.bzl#L12)
and is only used to [build the rocm_jaxX_plugin wheel](https://github.com/ROCm/rocm-jax/blob/master/jax_rocm_plugin/jaxlib_ext/tools/BUILD.bazel#L26).
and is only used to [build the jax-rocmX-plugin wheel](https://github.com/ROCm/rocm-jax/blob/master/jax_rocm_plugin/jaxlib_ext/tools/BUILD.bazel#L26).
Bazel applies a [handful of patches](https://github.com/ROCm/rocm-jax/blob/master/jax_rocm_plugin/third_party/jax/workspace.bzl#L14)
to the kernel code when it pulls jax-ml/jax. That kernel code is mostly stuff
that we share with Nvidia; changes to it from AMD are few and far between,
Expand Down
106 changes: 65 additions & 41 deletions jax_rocm_plugin/build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,25 @@
`python build/build.py requirements_update`
"""

# Define the build target for each wheel.
# Define the build target for each wheel. Use jax_wheel targets so that
# "bazel build" is used (enables --nobuild for prefetch-only workflows).
WHEEL_BUILD_TARGET_DICT = {
"jax-rocm-plugin": "//jaxlib_ext/tools:build_gpu_kernels_wheel",
"jax-rocm-pjrt": "//pjrt/tools:build_gpu_plugin_wheel",
"jax-rocm-plugin": "//jaxlib_ext/tools:jax_rocm7_plugin_wheel",
"jax-rocm-pjrt": "//pjrt/tools:jax_rocm7_pjrt_wheel",
}
# Editable-build counterpart of each wheel target. Looked up by wheel name when
# args.editable is set; the regular wheel target is the fallback if a wheel has
# no entry here.
WHEEL_BUILD_TARGET_EDITABLE_DICT = {
"jax-rocm-plugin": "//jaxlib_ext/tools:jax_rocm7_plugin_wheel_editable",
"jax-rocm-pjrt": "//pjrt/tools:jax_rocm7_pjrt_wheel_editable",
}
# Bazel output dir (under bazel-bin) for each wheel for copy step.
WHEEL_BAZEL_DIST_DIR = {
"jax-rocm-plugin": os.path.join("bazel-bin", "jaxlib_ext", "tools", "dist"),
"jax-rocm-pjrt": os.path.join("bazel-bin", "pjrt", "tools", "dist"),
}
# Wheel file prefix for glob when copying (e.g. jax_rocm7_plugin*.whl).
WHEEL_GLOB_PREFIX = {
"jax-rocm-plugin": "jax_rocm7_plugin",
"jax-rocm-pjrt": "jax_rocm7_pjrt",
}


Expand Down Expand Up @@ -425,7 +440,10 @@ async def main():
for option in args.bazel_startup_options:
bazel_command_base.append(option)

bazel_command_base.append("run")
if args.command == "requirements_update":
bazel_command_base.append("run")
else:
bazel_command_base.append("build")

if args.python_version:
# Do not add --repo_env=HERMETIC_PYTHON_VERSION with default args.python_version
Expand Down Expand Up @@ -478,6 +496,10 @@ async def main():

wheel_build_command_base = copy.deepcopy(bazel_command_base)

# So the wheel build action (BuildJaxWheel) can find tools like patchelf
# that are on PATH in the runner (e.g. manylinux container).
wheel_build_command_base.append("--action_env=PATH")

wheel_cpus = {
"darwin_arm64": "arm64",
"darwin_x86_64": "x86_64",
Expand Down Expand Up @@ -642,6 +664,9 @@ async def main():
f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}"
)

# Release version with no suffix (BUILDING.md: jax_rocm7_pjrt-X.X.X-...).
wheel_build_command_base.append("--repo_env=ML_WHEEL_TYPE=release")

# Append additional build options at the end to override any options set in
# .bazelrc or above.
if args.bazel_options:
Expand Down Expand Up @@ -683,6 +708,20 @@ async def main():
sys.exit(1)

wheel_build_command = copy.deepcopy(wheel_build_command_base)
# Pass build settings used by the jax_wheel rule (from @jax repo).
wheel_build_command.append("--@jax//jaxlib/tools:output_path=dist")
wheel_build_command.append(
f"--@jax//jaxlib/tools:jaxlib_git_hash={git_hash}"
)

if args.editable:
build_target = WHEEL_BUILD_TARGET_EDITABLE_DICT.get(
wheel, WHEEL_BUILD_TARGET_DICT[wheel]
)
else:
build_target = WHEEL_BUILD_TARGET_DICT[wheel]
wheel_build_command.append(build_target)

print("\n")
logger.info(
"Building %s for %s %s...",
Expand All @@ -691,43 +730,6 @@ async def main():
arch,
)

# Append the build target to the Bazel command.
build_target = WHEEL_BUILD_TARGET_DICT[wheel]
wheel_build_command.append(build_target)

wheel_build_command.append("--")

if args.editable:
logger.info("Building an editable build")
output_path = os.path.join(output_path, wheel)
wheel_build_command.append("--editable")

wheel_build_command.append(f'--output_path="{output_path}"')
wheel_build_command.append(f"--cpu={target_cpu}")

if "cuda" in wheel:
wheel_build_command.append("--enable-cuda=True")
if args.cuda_version:
cuda_major_version = args.cuda_version.split(".")[0]
else:
cuda_major_version = args.cuda_major_version
wheel_build_command.append(f"--platform_version={cuda_major_version}")

if "rocm" in wheel:
wheel_build_command.append("--enable-rocm=True")
wheel_build_command.append(f"--platform_version={args.rocm_version}")

wheel_build_command.append(f"--rocm_jax_git_hash={git_hash}")

use_local_xla = extract_override_path(args.bazel_options, "xla")
use_local_jax = extract_override_path(args.bazel_options, "jax")

if use_local_xla:
wheel_build_command.append(f"--use_local_xla={use_local_xla}")

if use_local_jax:
wheel_build_command.append(f"--use_local_jax={use_local_jax}")

result = await executor.run(
wheel_build_command.get_command_as_string(),
args.dry_run,
Expand All @@ -739,6 +741,28 @@ async def main():
f"Command failed with return code {result.return_code}"
)

# Copy built wheels from bazel-bin to user output_path.
output_path = args.output_path
for wheel in args.wheels.split(","):
if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel:
wheel = "jax-" + wheel
if wheel not in WHEEL_BAZEL_DIST_DIR:
continue
bazel_dir = WHEEL_BAZEL_DIST_DIR[wheel]
if not os.path.isdir(bazel_dir):
logging.warning(
"Bazel output dir not found: %s (skipping copy)", bazel_dir
)
continue
if args.editable:
src_dir = os.path.join(bazel_dir, WHEEL_GLOB_PREFIX[wheel])
dst_dir = os.path.join(output_path, wheel)
if os.path.isdir(src_dir):
utils.copy_dir_recursively(src_dir, dst_dir)
else:
glob_pattern = f"{WHEEL_GLOB_PREFIX[wheel]}*.whl"
utils.copy_individual_files(bazel_dir, output_path, glob_pattern)

# Exit with success if all wheels in the list were built successfully.
sys.exit(0)

Expand Down
45 changes: 28 additions & 17 deletions jax_rocm_plugin/build/rocm/tools/build_wheels.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ def build_plugin_wheel(
"--output_path=%s" % output_dir,
# Use roctracer (v1) instead of rocprofiler-sdk (v3) for profiling.
"--bazel_options=--define=xla_rocm_profiler=v1",
# Release version with no suffix (BUILDING.md: jax_rocm7_pjrt-X.X.X-...).
"--bazel_options=--repo_env=ML_WHEEL_TYPE=release",
]

# Add clang path if clang is used.
Expand All @@ -205,9 +207,12 @@ def build_plugin_wheel(
env["PATH"] = "%s:%s" % (py_bin, env["PATH"])

LOG.info("Running %r from cwd=%r", cmd, plugin_path)
pattern = re.compile("Output wheel: (.+)\n")

_run_scan_for_output(cmd, pattern, env=env, cwd=plugin_path, capture="stderr")
# With bazel build, wheels are copied to output_dir by build.py
result = subprocess.run(cmd, env=env, cwd=plugin_path, check=False)
if result.returncode != 0:
raise RuntimeError(
"Plugin wheel build failed with return code: %d" % result.returncode
)


# pylint: disable=R0913,R0917,too-many-locals
Expand Down Expand Up @@ -246,7 +251,6 @@ def build_jaxlib_wheel(
"--verbose",
"--bazel_options=--action_env=HIPCC_COMPILE_FLAGS_APPEND=--offload-compress",
"--bazel_options=--repo_env=ML_WHEEL_TYPE=release",
f"--bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX=+rocm{version_string}",
]

# Add clang path if clang is used.
Expand Down Expand Up @@ -338,8 +342,13 @@ def to_cpy_ver(python_version):
return "cp%d%d" % (int(tup[0]), int(tup[1]))


def fix_wheel(path, jax_path):
"""Fix auditwheel compliance using fixwheel.py and auditwheel."""
def fix_wheel(path, jax_path, wheelhouse_dir=None):
"""Fix auditwheel compliance using fixwheel.py and auditwheel.

auditwheel repair writes the repaired wheel to ./wheelhouse (relative to cwd).
Pass wheelhouse_dir so we run with cwd set to its parent; then do not copy
the original from dist_manylinux to avoid duplicates.
"""
try:
# NOTE(mrodden): fixwheel needs auditwheel 6.0.0, which has a min python of 3.8
# so use one of the CPythons in /opt to run
Expand All @@ -355,7 +364,9 @@ def fix_wheel(path, jax_path):

fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py")
cmd = ["python", fixwheel_path, path]
subprocess.run(cmd, check=True, env=env)
# So auditwheel repair writes to <cwd>/wheelhouse; use plugin dir as cwd.
cwd = os.path.dirname(wheelhouse_dir) if wheelhouse_dir else jax_path
subprocess.run(cmd, check=True, env=env, cwd=cwd)
LOG.info("Wheel fix completed successfully.")
except subprocess.CalledProcessError as cpe:
LOG.error("Subprocess failed with error: %s", cpe)
Expand Down Expand Up @@ -437,7 +448,9 @@ def main():
python_versions = args.python_versions.split(",")

manylinux_output_dir = "dist_manylinux"
wheelhouse_dir = "/wheelhouse/"
# Use plugin-relative wheelhouse so it works in Docker and when run on the host.
wheelhouse_dir = os.path.join(args.plugin_path, "wheelhouse")
os.makedirs(wheelhouse_dir, exist_ok=True)

rocm_path = args.rocm_path
if args.rocm_version:
Expand Down Expand Up @@ -482,11 +495,11 @@ def main():
args.compiler,
wheels="jax-rocm-pjrt",
)
# Fix PJRT wheel.
# Fix PJRT wheel (auditwheel repair writes repaired wheel to wheelhouse_dir).
wheel_paths = find_wheels(full_output_path)
for wheel_path in wheel_paths:
if "pjrt" in os.path.basename(wheel_path).lower():
fix_wheel(wheel_path, args.plugin_path)
fix_wheel(wheel_path, args.plugin_path, wheelhouse_dir)

# Build plugin wheel for each Python version.
for py in python_versions:
Expand All @@ -502,19 +515,17 @@ def main():
args.compiler,
wheels="jax-rocm-plugin",
)
# Fix plugin wheels for this Python version.
# Fix plugin wheels for this Python version (repaired wheels go to wheelhouse_dir).
wheel_paths = find_wheels(full_output_path)
for wheel_path in wheel_paths:
base = os.path.basename(wheel_path)
# Only fix plugin wheels, skip already-fixed PJRT wheel.
if "plugin" in base.lower():
fix_wheel(wheel_path, args.plugin_path)
fix_wheel(wheel_path, args.plugin_path, wheelhouse_dir)

# Copy plugin + PJRT wheels to wheelhouse.
wheel_paths = find_wheels(full_output_path)
for whl in wheel_paths:
LOG.info("Copying %s into %s", whl, wheelhouse_dir)
shutil.copy(whl, wheelhouse_dir)
# Repaired plugin/PJRT wheels are already in wheelhouse_dir from fix_wheel
# (auditwheel repair writes there). Do not copy originals from dist_manylinux
# to avoid duplicate wheels

# Optionally build jaxlib wheel if --jax-path is provided.
if args.jax_path:
Expand Down
46 changes: 43 additions & 3 deletions jax_rocm_plugin/build/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,14 @@ def get_gcc_major_version(gcc_path: str):

def get_jax_configure_bazel_options(bazel_command: list[str]):
"""Returns the bazel options to be written to .jax_configure.bazelrc."""
# Get the index of the "run" parameter. Build options will come after "run" so
# we find the index of "run" and filter everything after it.
start = bazel_command.index("run")
# Get the index of the command verb ("run" or "build"). Options come after it.
if "run" in bazel_command:
start = bazel_command.index("run")
elif "build" in bazel_command:
start = bazel_command.index("build")
else:
logging.error("Bazel command has neither 'run' nor 'build'")
return ""
jax_configure_bazel_options = ""
try:
for i in range(start + 1, len(bazel_command)):
Expand Down Expand Up @@ -260,6 +265,41 @@ def get_githash():
return ""


def copy_dir_recursively(src, dst):
    """Copy the directory tree rooted at src to dst, replacing dst.

    Any existing dst is removed first, so the result mirrors src instead of
    being merged with files left over from an earlier copy.

    Args:
        src: Path of the existing directory to copy from.
        dst: Path of the directory to (re)create.

    Raises:
        NotADirectoryError: If src is not an existing directory. This is
            checked before dst is touched, so a wrong source path cannot wipe
            a previous copy and then report success.
    """
    if not os.path.isdir(src):
        raise NotADirectoryError(f"Source directory not found: {src}")
    if os.path.exists(dst):
        shutil.rmtree(dst)
    os.makedirs(dst, exist_ok=True)
    # os.walk yields every directory under src, so empty directories are
    # recreated as well as those that contain files.
    for root, _dirs, files in os.walk(src):
        relative_path = os.path.relpath(root, src)
        dst_dir = os.path.join(dst, relative_path)
        os.makedirs(dst_dir, exist_ok=True)
        for f in files:
            src_file = os.path.join(root, f)
            dst_file = os.path.join(dst_dir, f)
            shutil.copy2(src_file, dst_file)
    logging.info("Editable wheel path: %s", dst)


def copy_individual_files(src: str, dst: str, glob_pattern: str):
    """Copy the entries of src that match glob_pattern into dst.

    Args:
        src: Directory to look in. It is treated as a literal path, so glob
            metacharacters in it (e.g. "[") are not interpreted.
        dst: Destination directory; created if it does not exist.
        glob_pattern: Glob pattern matched against entries of src, e.g.
            "jax_rocm7_plugin*.whl".
    """
    import glob as glob_module

    os.makedirs(dst, exist_ok=True)
    logging.debug(
        "Copying files matching pattern %r from %r to %r",
        glob_pattern,
        src,
        dst,
    )
    # Only glob_pattern should act as a pattern; escape the directory part so
    # a path such as "out[1]" still matches its own contents.
    search_expr = os.path.join(glob_module.escape(src), glob_pattern)
    # glob returns matches in arbitrary order; sort for a stable copy/log order.
    for f in sorted(glob_module.glob(search_expr)):
        dst_file = os.path.join(dst, os.path.basename(f))
        # Drop any earlier copy instead of overwriting it in place
        # (presumably because a previously copied read-only file could not be
        # reopened for writing).
        if os.path.exists(dst_file):
            os.remove(dst_file)
        shutil.copy2(f, dst_file)
        logging.info("Distribution path: %s", dst_file)


def _parse_string_as_bool(s):
"""Parses a string as a boolean value."""
lower = s.lower()
Expand Down
Loading
Loading