From 35e3bb6c1f6b56a099c1ceb35287252e7aee94e2 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Thu, 8 Jan 2026 18:52:12 -0600 Subject: [PATCH 01/10] Implement direct build in stack.py without Makefile --- stack.py | 209 ++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 184 insertions(+), 25 deletions(-) diff --git a/stack.py b/stack.py index f9b15fc8a9..6f64bf5b42 100644 --- a/stack.py +++ b/stack.py @@ -14,7 +14,8 @@ XLA_REPL_URL = "https://github.com/rocm/xla" DEFAULT_XLA_DIR = "../xla" -DEFAULT_KERNELS_JAX_DIR = "../jax" +DEFAULT_JAX_DIR = "../jax" + MAKE_TEMPLATE = r""" # gfx targets for which XLA and jax custom call kernels are built for @@ -50,12 +51,13 @@ ### -.PHONY: test clean install dist +.PHONY: test clean install dist all_wheels .default: dist dist: jax_rocm_plugin jax_rocm_pjrt +all_wheels: clean dist jaxlib_clean jaxlib jaxlib_install install jax_rocm_plugin: @@ -163,7 +165,22 @@ def find_clang(): return None -def _resolve_relative_paths(xla_dir: str, kernels_jax_dir: str) -> tuple[str, str, str]: +def get_amdgpu_targets(): + """Attempt to detect AMD GPU targets using rocminfo.""" + try: + out = subprocess.check_output( + "rocminfo | grep -o -m 1 'gfx.*'", shell=True, stderr=subprocess.DEVNULL + ).decode() + target = out.strip() + if target: + return target + except Exception: + pass + # Default targets if rocminfo fails + return "gfx906,gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201" + + +def _resolve_relative_paths(xla_dir: str, jax_dir: str) -> tuple[str, str, str]: """Transforms relative to absolute paths. This is needed to properly support symbolic information remapping""" this_repo_root = os.path.dirname(os.path.realpath(__file__)) @@ -177,16 +194,16 @@ def _resolve_relative_paths(xla_dir: str, kernels_jax_dir: str) -> tuple[str, st xla_path ), f"XLA path (specified as '{xla_dir}') doesn't resolve to existing directory at '{xla_path}'" - if kernels_jax_dir: + if jax_dir: kernels_jax_path = ( - kernels_jax_dir - if os.path.isabs(kernels_jax_dir) - else os.path.abspath(f"{this_repo_root}/jax_rocm_plugin/{kernels_jax_dir}") + jax_dir + if os.path.isabs(jax_dir) + else os.path.abspath(f"{this_repo_root}/jax_rocm_plugin/{jax_dir}") ) # pylint: disable=line-too-long assert os.path.isdir( kernels_jax_path - ), f"XLA path (specified as '{kernels_jax_dir}') doesn't resolve to existing directory at '{kernels_jax_path}'" + ), f"XLA path (specified as '{jax_dir}') doesn't resolve to existing directory at '{kernels_jax_path}'" else: kernels_jax_path = None return this_repo_root, xla_path, kernels_jax_path @@ -262,10 +279,11 @@ def setup_development( xla_ref: str, xla_dir: str, test_jax_ref: str, - kernels_jax_dir: str, + jax_dir: str, rebuild_makefile: bool = False, fix_bazel_symbols: bool = False, rocm_path: str = "/opt/rocm", + write_makefile: bool = True, ): """Clone jax and xla repos, and set up Makefile for developers""" @@ -286,9 +304,9 @@ def setup_development( # create build/install/test script makefile_path = "./jax_rocm_plugin/Makefile" - if rebuild_makefile or not os.path.exists(makefile_path) or fix_bazel_symbols: + if (rebuild_makefile or not os.path.exists(makefile_path) or fix_bazel_symbols) and write_makefile: this_repo_root, xla_path, kernels_jax_path = _resolve_relative_paths( - xla_dir, kernels_jax_dir + xla_dir, jax_dir ) if fix_bazel_symbols: plugin_bazel_options = "${PLUGIN_SYMBOLS}" @@ -346,6 +364,133 @@ def setup_development( mf.write(makefile_content) +def build_and_install( + xla_ref: str, + xla_dir: str, + test_jax_ref: str, + jax_dir: str, + rebuild_makefile: bool = False, + fix_bazel_symbols: bool = False, + rocm_path: str = "/opt/rocm", +): + """Run develop setup, then build all wheels and install jax""" + # Uninstall existing packages first + print("Uninstalling existing JAX packages...") + subprocess.run(["python3", "-m", "pip", "uninstall", "jax", "-y"], check=False) + subprocess.run( + [ + "python3", + "-m", + "pip", + "uninstall", + "jaxlib", + "jax-rocm-pjrt", + "jax-rocm-plugin", + "jax-plugin", + "-y", + ], + check=False, + ) + + # Setup repos (but don't necessarily generate Makefile) + setup_development( + xla_ref=xla_ref, + xla_dir=xla_dir, + test_jax_ref=test_jax_ref, + jax_dir=jax_dir, + rebuild_makefile=rebuild_makefile, + fix_bazel_symbols=fix_bazel_symbols, + rocm_path=rocm_path, + write_makefile=False, # Don't generate Makefile for direct build + ) + + this_repo_root, xla_path, jax_path = _resolve_relative_paths(xla_dir, jax_dir) + amdgpu_targets = get_amdgpu_targets() + clang_path = find_clang() or "/usr/lib/llvm-18/bin/clang" + + # ROCm version detection + try: + with open(os.path.join(rocm_path, ".info", "version"), encoding="utf-8") as f: + full_version = f.readline().strip() + rocm_version = full_version[0] + if rocm_version == "6": + rocm_version = "60" + except Exception: + rocm_version = "7" + + # Bazel options + bazel_options = [f"--override_repository=xla={xla_path}"] + if fix_bazel_symbols: + bazel_options.extend(["--strip=never", "--copt=-g3", "--cxxopt=-g3"]) + + plugin_bazel_options = list(bazel_options) + if fix_bazel_symbols: + plugin_bazel_options.append( + f"--copt=-fdebug-prefix-map=/proc/self/cwd={this_repo_root}/jax_rocm_plugin" + ) + plugin_bazel_options.append( + f"--cxxopt=-fdebug-prefix-map=/proc/self/cwd={this_repo_root}/jax_rocm_plugin" + ) + + jax_override = f"--override_repository=jax={jax_path}" if jax_path else "" + + def run_build(cwd, wheels, extra_options=None): + cmd = [ + "python3", + "./build/build.py", + "build", + "--use_clang=true", + f"--wheels={wheels}", + "--target_cpu_features=native", + f"--rocm_path={rocm_path}", + f"--rocm_version={rocm_version}", + f"--rocm_amdgpu_targets={amdgpu_targets}", + f"--clang_path={clang_path}", + "--verbose", + ] + opts = extra_options or bazel_options + for opt in opts: + cmd.extend(["--bazel_options", opt]) + if jax_override: + cmd.extend(["--bazel_options", jax_override]) + + print(f"Building {wheels} in {cwd}...") + subprocess.check_call(cmd, cwd=cwd) + + # 1. Build and install jaxlib + if jax_path: + # Clean jaxlib dist + subprocess.run(f"rm -f {jax_path}/dist/*", shell=True) + run_build(jax_path, "jaxlib") + print(f"Installing jaxlib from {jax_path}/dist...") + subprocess.run(f"pip install --force-reinstall {jax_path}/dist/*", shell=True) + + # 2. Build plugin and pjrt + subprocess.run("rm -rf dist", shell=True, cwd=this_repo_root) + run_build( + os.path.join(this_repo_root, "jax_rocm_plugin"), + "jax-rocm-plugin", + plugin_bazel_options, + ) + run_build( + os.path.join(this_repo_root, "jax_rocm_plugin"), + "jax-rocm-pjrt", + plugin_bazel_options, + ) + + # 3. Install plugin and pjrt + print("Installing jax-rocm-plugin and jax-rocm-pjrt...") + subprocess.run( + f"pip install --force-reinstall {this_repo_root}/jax_rocm_plugin/dist/*", + shell=True, + ) + + # 4. Install JAX from source + if jax_path: + print(f"Installing JAX from {jax_path}...") + subprocess.check_call(["pip", "install", "."], cwd=jax_path) + + def dev_docker(rm): """Start a docker container for local plugin development""" cur_abs_path = os.path.abspath(os.curdir) @@ -401,18 +546,19 @@ def parse_args(): subp = p.add_subparsers(dest="action", required=True) - dev = subp.add_parser("develop") - dev.add_argument( + # Common arguments for develop and build + common = argparse.ArgumentParser(add_help=False) + common.add_argument( "--rebuild-makefile", help="Force rebuild of Makefile from template.", action="store_true", ) - dev.add_argument( + common.add_argument( "--xla-ref", help="XLA commit reference to checkout on clone", default=XLA_REPO_REF, ) - dev.add_argument( + common.add_argument( "--xla-dir", help=( "Set the XLA path in the Makefile. This must either be a path " @@ -420,22 +566,21 @@ def parse_args(): ), default=DEFAULT_XLA_DIR, ) - dev.add_argument( + common.add_argument( "--jax-ref", help="JAX commit reference to checkout on clone", default=TEST_JAX_REPO_REF, ) - dev.add_argument( - "--kernel-jax-dir", + common.add_argument( + "--jax-dir", help=( "If you want to use a local JAX directory for building the " "plugin kernels wheel (jax_rocm7_plugin), the path to the " - "directory of repo. Defaults to %s" % DEFAULT_KERNELS_JAX_DIR + "directory of repo. Defaults to %s" % DEFAULT_JAX_DIR ), - default=DEFAULT_KERNELS_JAX_DIR, + default=DEFAULT_JAX_DIR, ) - - dev.add_argument( + common.add_argument( "--fix-bazel-symbols", help="When this option is enabled, the script assumes you need to build " "code in a release with symbolic info configuration to alleviate debugging. " @@ -443,13 +588,17 @@ def parse_args(): "links to corresponding workspaces pointing to bazel's dependencies storage.", action="store_true", ) - - dev.add_argument( + common.add_argument( "--rocm-path", help="Location of the ROCm to use for building Jax", default="/opt/rocm", ) + subp.add_parser("develop", parents=[common], help="Setup development environment") + subp.add_parser( + "build", parents=[common], help="Setup, build all wheels, and install JAX" + ) + doc_parser = subp.add_parser("docker") doc_parser.add_argument( "--rm", @@ -469,11 +618,21 @@ def main(): xla_ref=args.xla_ref, xla_dir=args.xla_dir, test_jax_ref=args.jax_ref, - kernels_jax_dir=args.kernel_jax_dir, + jax_dir=args.jax_dir, rebuild_makefile=args.rebuild_makefile, fix_bazel_symbols=args.fix_bazel_symbols, rocm_path=args.rocm_path, ) + elif args.action == "build": + build_and_install( + xla_ref=args.xla_ref, + xla_dir=args.xla_dir, + test_jax_ref=args.jax_ref, + jax_dir=args.jax_dir, + rebuild_makefile=True, + fix_bazel_symbols=args.fix_bazel_symbols, + rocm_path=args.rocm_path, + ) if __name__ == "__main__": From 9b92532f8c04e52350fafba6811bbbf96ffaeae5 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Thu, 8 Jan 2026 19:02:05 -0600 Subject: [PATCH 02/10] Fix --bazel_options argument parsing and refine build paths in stack.py --- stack.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/stack.py b/stack.py index 6f64bf5b42..14e2071b54 100644 --- a/stack.py +++ b/stack.py @@ -450,9 +450,9 @@ def run_build(cwd, wheels, extra_options=None): ] opts = extra_options or bazel_options for opt in opts: - cmd.extend(["--bazel_options", opt]) + cmd.append(f"--bazel_options={opt}") if jax_override: - cmd.extend(["--bazel_options", jax_override]) + cmd.append(f"--bazel_options={jax_override}") print(f"Building {wheels} in {cwd}...") subprocess.check_call(cmd, cwd=cwd) @@ -460,20 +460,21 @@ def run_build(cwd, wheels, extra_options=None): # 1. Build and install jaxlib if jax_path: # Clean jaxlib dist - subprocess.run(f"rm -f {jax_path}/dist/*", shell=True) + subprocess.run(f"rm -f {jax_path}/dist/*", shell=True, check=True) run_build(jax_path, "jaxlib") print(f"Installing jaxlib from {jax_path}/dist...") - subprocess.run(f"pip install --force-reinstall {jax_path}/dist/*", shell=True) + subprocess.run(f"pip install --force-reinstall {jax_path}/dist/*", shell=True, check=True) # 2. Build plugin and pjrt - subprocess.run("rm -rf dist", shell=True, cwd=this_repo_root) + plugin_path = os.path.join(this_repo_root, "jax_rocm_plugin") + subprocess.run("rm -rf dist", shell=True, cwd=plugin_path, check=True) run_build( - os.path.join(this_repo_root, "jax_rocm_plugin"), + plugin_path, "jax-rocm-plugin", plugin_bazel_options, ) run_build( - os.path.join(this_repo_root, "jax_rocm_plugin"), + plugin_path, "jax-rocm-pjrt", plugin_bazel_options, ) @@ -481,8 +482,9 @@ def run_build(cwd, wheels, extra_options=None): # 3. Install plugin and pjrt print("Installing jax-rocm-plugin and jax-rocm-pjrt...") subprocess.run( - f"pip install --force-reinstall {this_repo_root}/jax_rocm_plugin/dist/*", + f"pip install --force-reinstall {plugin_path}/dist/*", shell=True, + check=True, ) # 4. Install JAX from source From 9965e81c777e9ebbef276c5961dd675037432194 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Thu, 8 Jan 2026 19:28:34 -0600 Subject: [PATCH 03/10] Fix pylint and CI lint checks in stack.py --- stack.py | 63 +++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 40 insertions(+), 23 deletions(-) diff --git a/stack.py b/stack.py index 14e2071b54..da2800722d 100644 --- a/stack.py +++ b/stack.py @@ -168,16 +168,21 @@ def find_clang(): def get_amdgpu_targets(): """Attempt to detect AMD GPU targets using rocminfo.""" try: + cmd = "rocminfo | grep -o -m 1 'gfx.*'" out = subprocess.check_output( - "rocminfo | grep -o -m 1 'gfx.*'", shell=True, stderr=subprocess.DEVNULL + cmd, shell=True, stderr=subprocess.DEVNULL ).decode() target = out.strip() if target: return target - except Exception: + except (subprocess.SubprocessError, OSError): pass # Default targets if rocminfo fails - return "gfx906,gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201" + targets = ( + "gfx906,gfx908,gfx90a,gfx942,gfx950,gfx1030," + + "gfx1100,gfx1101,gfx1200,gfx1201" + ) + return targets def _resolve_relative_paths(xla_dir: str, jax_dir: str) -> tuple[str, str, str]: @@ -190,9 +195,10 @@ def _resolve_relative_paths(xla_dir: str, jax_dir: str) -> tuple[str, str, str]: if os.path.isabs(xla_dir) else os.path.abspath(f"{this_repo_root}/jax_rocm_plugin/{xla_dir}") ) - assert os.path.isdir( - xla_path - ), f"XLA path (specified as '{xla_dir}') doesn't resolve to existing directory at '{xla_path}'" + assert os.path.isdir(xla_path), ( + f"XLA path (specified as '{xla_dir}') doesn't resolve to " + f"existing directory at '{xla_path}'" + ) if jax_dir: kernels_jax_path = ( @@ -201,9 +207,10 @@ def _resolve_relative_paths(xla_dir: str, jax_dir: str) -> tuple[str, str, str]: else os.path.abspath(f"{this_repo_root}/jax_rocm_plugin/{jax_dir}") ) # pylint: disable=line-too-long - assert os.path.isdir( - kernels_jax_path - ), f"XLA path (specified as '{jax_dir}') doesn't resolve to existing directory at '{kernels_jax_path}'" + assert os.path.isdir(kernels_jax_path), ( + f"XLA path (specified as '{jax_dir}') doesn't resolve to " + f"existing directory at '{kernels_jax_path}'" + ) else: kernels_jax_path = None return this_repo_root, xla_path, kernels_jax_path @@ -234,7 +241,7 @@ def _add_externals_symlink(this_repo_root: str, xla_path: str, kernels_jax_path: print( f"Bazelisk is detected (bazel=={v}), proceeding with creation of symlinks" ) - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught print( "WARNING: Bazelisk is NOT detected and a wrapper for specific bazel " "versions isn't implemented. Symlinks to '$(bazel info output_base)/external' " @@ -263,7 +270,7 @@ def _make_external(wrkspace: str): .stdout.decode("utf-8") .rstrip() ) - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught print(f"Failed to query 'bazel info output_base' for '{wrkspace}':{e}") return _link(f"{output_base}/external", f"{wrkspace}/external") @@ -304,7 +311,8 @@ def setup_development( # create build/install/test script makefile_path = "./jax_rocm_plugin/Makefile" - if (rebuild_makefile or not os.path.exists(makefile_path) or fix_bazel_symbols) and write_makefile: + rebuild = rebuild_makefile or not os.path.exists(makefile_path) + if (rebuild or fix_bazel_symbols) and write_makefile: this_repo_root, xla_path, kernels_jax_path = _resolve_relative_paths( xla_dir, jax_dir ) @@ -314,7 +322,9 @@ def setup_development( custom_options = " ${CFG_RELEASE_WITH_SYM}" _add_externals_symlink(this_repo_root, xla_path, kernels_jax_path) else: # not modifying the build unless asked - plugin_bazel_options, jaxlib_bazel_options, custom_options = "", "", "" + plugin_bazel_options = "" + jaxlib_bazel_options = "" + custom_options = "" # try to detect the namespace version from the ROCm version # this is expected to throw an exception if the specified ROCm path is invalid, for example @@ -376,7 +386,10 @@ def build_and_install( """Run develop setup, then build all wheels and install jax""" # Uninstall existing packages first print("Uninstalling existing JAX packages...") - subprocess.run(["python3", "-m", "pip", "uninstall", "jax", "-y"], check=False) + subprocess.run( + ["python3", "-m", "pip", "uninstall", "jax", "-y"], + check=False, + ) subprocess.run( [ "python3", @@ -410,12 +423,13 @@ def build_and_install( # ROCm version detection try: - with open(os.path.join(rocm_path, ".info", "version"), encoding="utf-8") as f: + version_file = os.path.join(rocm_path, ".info", "version") + with open(version_file, encoding="utf-8") as f: full_version = f.readline().strip() rocm_version = full_version[0] if rocm_version == "6": rocm_version = "60" - except Exception: + except (OSError, IndexError): rocm_version = "7" # Bazel options @@ -425,12 +439,11 @@ def build_and_install( plugin_bazel_options = list(bazel_options) if fix_bazel_symbols: - plugin_bazel_options.append( - f"--copt=-fdebug-prefix-map=/proc/self/cwd={this_repo_root}/jax_rocm_plugin" - ) - plugin_bazel_options.append( - f"--cxxopt=-fdebug-prefix-map=/proc/self/cwd={this_repo_root}/jax_rocm_plugin" - ) + plugin_root = f"{this_repo_root}/jax_rocm_plugin" + map_copt = f"--copt=-fdebug-prefix-map=/proc/self/cwd={plugin_root}" + plugin_bazel_options.append(map_copt) + map_cxxopt = f"--cxxopt=-fdebug-prefix-map=/proc/self/cwd={plugin_root}" + plugin_bazel_options.append(map_cxxopt) jax_override = f"--override_repository=jax={jax_path}" if jax_path else "" @@ -463,7 +476,11 @@ def run_build(cwd, wheels, extra_options=None): subprocess.run(f"rm -f {jax_path}/dist/*", shell=True, check=True) run_build(jax_path, "jaxlib") print(f"Installing jaxlib from {jax_path}/dist...") - subprocess.run(f"pip install --force-reinstall {jax_path}/dist/*", shell=True, check=True) + subprocess.run( + f"pip install --force-reinstall {jax_path}/dist/*", + shell=True, + check=True, + ) # 2. Build plugin and pjrt plugin_path = os.path.join(this_repo_root, "jax_rocm_plugin") From 21840fc8e4f275dd6ec954f15ece9f7d6c695b86 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Fri, 9 Jan 2026 11:24:58 -0600 Subject: [PATCH 04/10] Add --debug build support to stack.py --- stack.py | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/stack.py b/stack.py index da2800722d..c7bc19a9d2 100644 --- a/stack.py +++ b/stack.py @@ -291,6 +291,7 @@ def setup_development( fix_bazel_symbols: bool = False, rocm_path: str = "/opt/rocm", write_makefile: bool = True, + debug: bool = False, ): """Clone jax and xla repos, and set up Makefile for developers""" @@ -312,11 +313,16 @@ def setup_development( # create build/install/test script makefile_path = "./jax_rocm_plugin/Makefile" rebuild = rebuild_makefile or not os.path.exists(makefile_path) - if (rebuild or fix_bazel_symbols) and write_makefile: + if (rebuild or fix_bazel_symbols or debug) and write_makefile: this_repo_root, xla_path, kernels_jax_path = _resolve_relative_paths( xla_dir, jax_dir ) - if fix_bazel_symbols: + if debug: + plugin_bazel_options = "${PLUGIN_SYMBOLS}" + jaxlib_bazel_options = "${JAXLIB_SYMBOLS}" + custom_options = " ${CFG_DEBUG}" + _add_externals_symlink(this_repo_root, xla_path, kernels_jax_path) + elif fix_bazel_symbols: plugin_bazel_options = "${PLUGIN_SYMBOLS}" jaxlib_bazel_options = "${JAXLIB_SYMBOLS}" custom_options = " ${CFG_RELEASE_WITH_SYM}" @@ -382,6 +388,7 @@ def build_and_install( rebuild_makefile: bool = False, fix_bazel_symbols: bool = False, rocm_path: str = "/opt/rocm", + debug: bool = False, ): """Run develop setup, then build all wheels and install jax""" # Uninstall existing packages first @@ -415,6 +422,7 @@ def build_and_install( fix_bazel_symbols=fix_bazel_symbols, rocm_path=rocm_path, write_makefile=False, # Don't generate Makefile for direct build + debug=debug, ) this_repo_root, xla_path, jax_path = _resolve_relative_paths(xla_dir, jax_dir) @@ -434,7 +442,19 @@ def build_and_install( # Bazel options bazel_options = [f"--override_repository=xla={xla_path}"] - if fix_bazel_symbols: + if debug: + bazel_options.extend( + [ + "--config=debug", + "--compilation_mode=dbg", + "--strip=never", + "--copt=-g3", + "--copt=-O0", + "--cxxopt=-g3", + "--cxxopt=-O0", + ] + ) + elif fix_bazel_symbols: bazel_options.extend(["--strip=never", "--copt=-g3", "--cxxopt=-g3"]) plugin_bazel_options = list(bazel_options) @@ -607,6 +627,11 @@ def parse_args(): "links to corresponding workspaces pointing to bazel's dependencies storage.", action="store_true", ) + common.add_argument( + "--debug", + help="Build in debug mode (unoptimized with full debug symbols).", + action="store_true", + ) common.add_argument( "--rocm-path", help="Location of the ROCm to use for building Jax", @@ -641,6 +666,7 @@ def main(): rebuild_makefile=args.rebuild_makefile, fix_bazel_symbols=args.fix_bazel_symbols, rocm_path=args.rocm_path, + debug=args.debug, ) elif args.action == "build": build_and_install( @@ -651,6 +677,7 @@ def main(): rebuild_makefile=True, fix_bazel_symbols=args.fix_bazel_symbols, rocm_path=args.rocm_path, + debug=args.debug, ) From 124c262b398ccb71333fd81ac9c5f4b649ceb5b1 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 14 Jan 2026 07:53:13 -0600 Subject: [PATCH 05/10] Add --clang-path support to stack.py --- stack.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/stack.py b/stack.py index c7bc19a9d2..5e62ddb4df 100644 --- a/stack.py +++ b/stack.py @@ -292,6 +292,7 @@ def setup_development( rocm_path: str = "/opt/rocm", write_makefile: bool = True, debug: bool = False, + clang_path: str = None, ): """Clone jax and xla repos, and set up Makefile for developers""" @@ -348,7 +349,7 @@ def setup_development( print(f"Warning: using unexpected ROCm version {plugin_namespace_version}") kvs = { - "clang_path": "/usr/lib/llvm-18/bin/clang", + "clang_path": clang_path or "/usr/lib/llvm-18/bin/clang", "plugin_version": plugin_namespace_version, "this_repo_root": this_repo_root, "xla_path": xla_path, @@ -367,12 +368,15 @@ def setup_development( "rocm_path": rocm_path, } - clang_path = find_clang() - if clang_path: - print("Found clang at %r" % clang_path) - kvs["clang_path"] = clang_path + if not clang_path: + clang_discovered = find_clang() + if clang_discovered: + print("Found clang at %r" % clang_discovered) + kvs["clang_path"] = clang_discovered + else: + print("No clang found. Defaulting to %r" % kvs["clang_path"]) else: - print("No clang found. Defaulting to %r" % kvs["clang_path"]) + print("Using provided clang at %r" % clang_path) makefile_content = MAKE_TEMPLATE % kvs @@ -389,6 +393,7 @@ def build_and_install( fix_bazel_symbols: bool = False, rocm_path: str = "/opt/rocm", debug: bool = False, + clang_path: str = None, ): """Run develop setup, then build all wheels and install jax""" # Uninstall existing packages first @@ -423,11 +428,16 @@ def build_and_install( rocm_path=rocm_path, write_makefile=False, # Don't generate Makefile for direct build debug=debug, + clang_path=clang_path, ) this_repo_root, xla_path, jax_path = _resolve_relative_paths(xla_dir, jax_dir) amdgpu_targets = get_amdgpu_targets() - clang_path = find_clang() or "/usr/lib/llvm-18/bin/clang" + + if not clang_path: + clang_path = find_clang() or "/usr/lib/llvm-18/bin/clang" + else: + print("Using provided clang at %r" % clang_path) # ROCm version detection try: @@ -637,6 +647,11 @@ def parse_args(): help="Location of the ROCm to use for building Jax", default="/opt/rocm", ) + common.add_argument( + "--clang-path", + help="Path to the clang compiler to use.", + default=None, + ) subp.add_parser("develop", parents=[common], help="Setup development environment") subp.add_parser( @@ -667,6 +682,7 @@ def main(): fix_bazel_symbols=args.fix_bazel_symbols, rocm_path=args.rocm_path, debug=args.debug, + clang_path=args.clang_path, ) elif args.action == "build": build_and_install( @@ -678,6 +694,7 @@ def main(): fix_bazel_symbols=args.fix_bazel_symbols, rocm_path=args.rocm_path, debug=args.debug, + clang_path=args.clang_path, ) From e332c63f096cd5423710686b6aa4361cd439c19d Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 14 Jan 2026 07:56:42 -0600 Subject: [PATCH 06/10] Prioritize LLVM 18 in find_clang in stack.py --- stack.py | 40 ++++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/stack.py b/stack.py index 5e62ddb4df..08ada6aa91 100644 --- a/stack.py +++ b/stack.py @@ -137,29 +137,37 @@ def find_clang(): """Find a local clang compiler and return its file path.""" - clang_path = None + # 1. Prioritize specific LLVM versions + llvm_paths = [ + "/usr/lib/llvm-18/bin/clang" + ] + for path in llvm_paths: + if os.path.exists(path): + return path - # check PATH + # 2. Check PATH try: out = subprocess.check_output(["which", "clang"]) - clang_path = out.decode("utf-8").strip() - return clang_path + return out.decode("utf-8").strip() except subprocess.CalledProcessError: pass - # search /usr/lib/ + # 3. search /usr/lib/llvm* directories as a fallback top = "/usr/lib" - for root, dirs, files in os.walk(top): - # only walk llvm dirs - if root == top: - for d in dirs: - if not d.startswith("llvm"): - dirs.remove(d) - - for f in files: - if f == "clang": - clang_path = os.path.join(root, f) - return clang_path + if os.path.exists(top): + for root, dirs, files in os.walk(top): + # only walk llvm dirs + if root == top: + # Prioritize higher versions by sorting reverse + dirs.sort(reverse=True) + for d in list(dirs): + if not d.startswith("llvm"): + dirs.remove(d) + + for f in files: + if f == "clang": + clang_path = os.path.join(root, f) + return clang_path # We didn't find a clang install return None From a4b248fad4aaef9b1180378ccd54d3753c11bb0b Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 14 Jan 2026 18:06:39 -0600 Subject: [PATCH 07/10] Refactor build_and_install and fix python version mismatch in wheels --- stack.py | 106 +++++++++++++++++++++++++++++++------------------------ 1 file changed, 59 insertions(+), 47 deletions(-) diff --git a/stack.py b/stack.py index 08ada6aa91..a1392acf05 100644 --- a/stack.py +++ b/stack.py @@ -392,19 +392,8 @@ def setup_development( mf.write(makefile_content) -def build_and_install( - xla_ref: str, - xla_dir: str, - test_jax_ref: str, - jax_dir: str, - rebuild_makefile: bool = False, - fix_bazel_symbols: bool = False, - rocm_path: str = "/opt/rocm", - debug: bool = False, - clang_path: str = None, -): - """Run develop setup, then build all wheels and install jax""" - # Uninstall existing packages first +def uninstall_jax_packages(): + """Uninstall existing JAX packages.""" print("Uninstalling existing JAX packages...") subprocess.run( ["python3", "-m", "pip", "uninstall", "jax", "-y"], @@ -425,29 +414,9 @@ def build_and_install( check=False, ) - # Setup repos (but don't necessarily generate Makefile) - setup_development( - xla_ref=xla_ref, - xla_dir=xla_dir, - test_jax_ref=test_jax_ref, - jax_dir=jax_dir, - rebuild_makefile=rebuild_makefile, - fix_bazel_symbols=fix_bazel_symbols, - rocm_path=rocm_path, - write_makefile=False, # Don't generate Makefile for direct build - debug=debug, - clang_path=clang_path, - ) - - this_repo_root, xla_path, jax_path = _resolve_relative_paths(xla_dir, jax_dir) - amdgpu_targets = get_amdgpu_targets() - if not clang_path: - clang_path = find_clang() or "/usr/lib/llvm-18/bin/clang" - else: - print("Using provided clang at %r" % clang_path) - - # ROCm version detection +def get_rocm_version(rocm_path): + """Detect ROCm version from path.""" try: version_file = os.path.join(rocm_path, ".info", "version") with open(version_file, encoding="utf-8") as f: @@ -455,10 +424,15 @@ def build_and_install( rocm_version = full_version[0] if rocm_version == "6": rocm_version = "60" + return rocm_version except (OSError, IndexError): - rocm_version = "7" + return "7" + - # Bazel options +def get_build_options( + xla_path, jax_path, debug, fix_bazel_symbols, this_repo_root +): + """Calculate bazel options for build.""" bazel_options = [f"--override_repository=xla={xla_path}"] if debug: bazel_options.extend( @@ -484,6 +458,51 @@ def build_and_install( plugin_bazel_options.append(map_cxxopt) jax_override = f"--override_repository=jax={jax_path}" if jax_path else "" + return bazel_options, plugin_bazel_options, jax_override + + +def build_and_install( + xla_ref: str, + xla_dir: str, + test_jax_ref: str, + jax_dir: str, + rebuild_makefile: bool = False, + fix_bazel_symbols: bool = False, + rocm_path: str = "/opt/rocm", + debug: bool = False, + clang_path: str = None, +): + """Run develop setup, then build all wheels and install jax""" + uninstall_jax_packages() + + # Setup repos (but don't necessarily generate Makefile) + setup_development( + xla_ref=xla_ref, + xla_dir=xla_dir, + test_jax_ref=test_jax_ref, + jax_dir=jax_dir, + rebuild_makefile=rebuild_makefile, + fix_bazel_symbols=fix_bazel_symbols, + rocm_path=rocm_path, + write_makefile=False, # Don't generate Makefile for direct build + debug=debug, + clang_path=clang_path, + ) + + this_repo_root, xla_path, jax_path = _resolve_relative_paths(xla_dir, jax_dir) + amdgpu_targets = get_amdgpu_targets() + python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + + if not clang_path: + clang_path = find_clang() or "/usr/lib/llvm-18/bin/clang" + else: + print("Using provided clang at %r" % clang_path) + + rocm_version = get_rocm_version(rocm_path) + + bazel_options, plugin_bazel_options, jax_override = get_build_options( + xla_path, jax_path, debug, fix_bazel_symbols, this_repo_root + ) def run_build(cwd, wheels, extra_options=None): cmd = [ @@ -493,6 +512,7 @@ def run_build(cwd, wheels, extra_options=None): "--use_clang=true", f"--wheels={wheels}", "--target_cpu_features=native", + f"--python_version={python_version}", f"--rocm_path={rocm_path}", f"--rocm_version={rocm_version}", f"--rocm_amdgpu_targets={amdgpu_targets}", @@ -523,16 +543,8 @@ def run_build(cwd, wheels, extra_options=None): # 2. Build plugin and pjrt plugin_path = os.path.join(this_repo_root, "jax_rocm_plugin") subprocess.run("rm -rf dist", shell=True, cwd=plugin_path, check=True) - run_build( - plugin_path, - "jax-rocm-plugin", - plugin_bazel_options, - ) - run_build( - plugin_path, - "jax-rocm-pjrt", - plugin_bazel_options, - ) + run_build(plugin_path, "jax-rocm-plugin", plugin_bazel_options) + run_build(plugin_path, "jax-rocm-pjrt", plugin_bazel_options) # 3. Install plugin and pjrt print("Installing jax-rocm-plugin and jax-rocm-pjrt...") From fe9a6d5465db2eadef199f83d13062c06b99401d Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 14 Jan 2026 18:08:43 -0600 Subject: [PATCH 08/10] Add missing sys and re imports to stack.py --- stack.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/stack.py b/stack.py index a1392acf05..4c29b2ce8a 100644 --- a/stack.py +++ b/stack.py @@ -3,7 +3,9 @@ import argparse import os +import re import subprocess +import sys TEST_JAX_REPO_REF = "rocm-jaxlib-v0.8.0" From 7638ab24855cf3c3a7965328c6a1041bf48d1815 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 14 Jan 2026 18:13:22 -0600 Subject: [PATCH 09/10] Force system python and ensure gfx950 support in stack.py --- stack.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/stack.py b/stack.py index 4c29b2ce8a..ea64c3919f 100644 --- a/stack.py +++ b/stack.py @@ -189,8 +189,8 @@ def get_amdgpu_targets(): pass # Default targets if rocminfo fails targets = ( - "gfx906,gfx908,gfx90a,gfx942,gfx950,gfx1030," - + "gfx1100,gfx1101,gfx1200,gfx1201" + "gfx906,gfx908,gfx90a,gfx942,gfx950,gfx1030," + + "gfx1100,gfx1101,gfx1200,gfx1201" ) return targets @@ -493,7 +493,9 @@ def build_and_install( this_repo_root, xla_path, jax_path = _resolve_relative_paths(xla_dir, jax_dir) amdgpu_targets = get_amdgpu_targets() - python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + + # Use "system" to force build to use the container's python version + python_version = "system" if not clang_path: clang_path = find_clang() or "/usr/lib/llvm-18/bin/clang" From db6254a00bc01c4313b8863d1b216471902f8f0a Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 14 Jan 2026 18:14:55 -0600 Subject: [PATCH 10/10] Revert python_version to auto-detection --- stack.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/stack.py b/stack.py index ea64c3919f..682e046201 100644 --- a/stack.py +++ b/stack.py @@ -493,9 +493,7 @@ def build_and_install( this_repo_root, xla_path, jax_path = _resolve_relative_paths(xla_dir, jax_dir) amdgpu_targets = get_amdgpu_targets() - - # Use "system" to force build to use the container's python version - python_version = "system" + python_version = f"{sys.version_info.major}.{sys.version_info.minor}" if not clang_path: clang_path = find_clang() or "/usr/lib/llvm-18/bin/clang"