-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathstack.py
More file actions
473 lines (386 loc) · 15.7 KB
/
stack.py
File metadata and controls
473 lines (386 loc) · 15.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
#!/usr/bin/env python3
"""Script for setting up local development environments"""
import argparse
import os
import subprocess
# Git refs checked out when cloning: the JAX repo used for running unit tests
# and the XLA repo used for building jax_rocm_plugin.
TEST_JAX_REPO_REF = "rocm-jaxlib-v0.8.2"
XLA_REPO_REF = "rocm-jaxlib-v0.8.2"

# Clone URLs for the ROCm forks (the "REPL" spelling is kept because these
# names are referenced elsewhere in this script).
JAX_REPL_URL = "https://github.com/rocm/jax"
XLA_REPL_URL = "https://github.com/rocm/xla"

# Default checkout locations, interpreted relative to the jax_rocm_plugin
# directory unless an absolute path is given.
DEFAULT_XLA_DIR = "../xla"
DEFAULT_KERNELS_JAX_DIR = "../jax"
# Makefile template for jax_rocm_plugin/Makefile, filled in with %-style
# mapping keys (%(name)s) by setup_development(). Any literal percent sign
# added here must be written as %%. Recipe lines MUST start with a TAB.
MAKE_TEMPLATE = r"""
# gfx targets for which XLA and jax custom call kernels are built for
# AMDGPU_TARGETS ?= "gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201"
# customize to a single arch for local dev builds to reduce compile time
AMDGPU_TARGETS ?= "$(shell rocminfo | grep -o -m 1 'gfx.*')"

###### auxiliary vars. Note the absence of quotes around variable values, - these vars are expected to be put into other quoted vars
# Bazel options to build repos in a certain mode.
CFG_DEBUG=--config=debug --compilation_mode=dbg --strip=never --copt=-g3 --copt=-O0 --cxxopt=-g3 --cxxopt=-O0
CFG_RELEASE_WITH_SYM=--strip=never --copt=-g3 --cxxopt=-g3

# Sets '-fdebug-prefix-map=' compiler parameter to remap source file locations from bazel's reproducible builds
# sandbox /proc/self/cwd to correct local paths. Note, external dependencies support require 'external' symlink
# in a corresponding bazel workspace root
PLUGIN_SYMBOLS=--copt=-fdebug-prefix-map=/proc/self/cwd=%(this_repo_root)s/jax_rocm_plugin --cxxopt=-fdebug-prefix-map=/proc/self/cwd=%(this_repo_root)s/jax_rocm_plugin
JAXLIB_SYMBOLS=--copt=-fdebug-prefix-map=/proc/self/cwd=%(kernels_jax_path)s --cxxopt=-fdebug-prefix-map=/proc/self/cwd=%(kernels_jax_path)s

###### --bazel_options values, must be enquoted
# Defines a value for '--bazel_options' for each of 3 build types (pjrt, plugin + jaxlib).
# By default, uses local XLA for each wheel. Redefine to whatever option is needed for your case
ALL_BAZEL_OPTIONS="--override_repository=xla=%(xla_path)s%(custom_options)s"
# PLUGIN_BAZEL_OPTIONS and JAXLIB_BAZEL_OPTIONS define pjrt&plugin specific bazel options and jaxlib specific build options.
PLUGIN_BAZEL_OPTIONS="%(plugin_bazel_options)s"
JAXLIB_BAZEL_OPTIONS="%(jaxlib_bazel_options)s"
# Use your local JAX for building the kernels in jax_rocm_plugin
# KERNELS_JAX_OVERRIDE_OPTION="--override_repository=jax=../jax"
KERNELS_JAX_OVERRIDE_OPTION="%(kernels_jax_override)s"
###

# None of these targets produce a file of the same name; declaring them phony
# keeps a same-named file or directory from making Make skip the recipe.
.PHONY: dist jax_rocm_plugin jax_rocm_pjrt clean install refresh test jaxlib jaxlib_clean jaxlib_install refresh_jaxlib
# Running 'make' with no target builds both wheels.
.DEFAULT_GOAL := dist

dist: jax_rocm_plugin jax_rocm_pjrt

jax_rocm_plugin:
	python3 ./build/build.py build \
		--use_clang=true \
		--wheels=jax-rocm-plugin \
		--target_cpu_features=native \
		--rocm_path=%(rocm_path)s \
		--rocm_version=%(plugin_version)s \
		--rocm_amdgpu_targets=${AMDGPU_TARGETS} \
		--bazel_options=${ALL_BAZEL_OPTIONS} \
		--bazel_options=${PLUGIN_BAZEL_OPTIONS} \
		--bazel_options=${KERNELS_JAX_OVERRIDE_OPTION} \
		--verbose \
		--clang_path=%(clang_path)s

jax_rocm_pjrt:
	python3 ./build/build.py build \
		--use_clang=true \
		--wheels=jax-rocm-pjrt \
		--target_cpu_features=native \
		--rocm_path=%(rocm_path)s \
		--rocm_version=%(plugin_version)s \
		--rocm_amdgpu_targets=${AMDGPU_TARGETS} \
		--bazel_options=${ALL_BAZEL_OPTIONS} \
		--bazel_options=${PLUGIN_BAZEL_OPTIONS} \
		--bazel_options=${KERNELS_JAX_OVERRIDE_OPTION} \
		--verbose \
		--clang_path=%(clang_path)s

clean:
	rm -rf dist

install: dist
	pip install --force-reinstall dist/*

refresh: clean dist install

test:
	python3 tests/test_plugin.py

# Sometimes developers might want to build their own jaxlib. Usually, we can
# just use the one from upstream, but we might want to build our own if we
# suspect that jaxlib isn't loading the plugin properly or if ROCm-specific
# code is somehow making its way into jaxlib.
jaxlib:
	(cd %(kernels_jax_path)s && python3 ./build/build.py build \
		--target_cpu_features=native \
		--use_clang=true \
		--clang_path=%(clang_path)s \
		--wheels=jaxlib \
		--bazel_options=${ALL_BAZEL_OPTIONS} \
		--bazel_options=${JAXLIB_BAZEL_OPTIONS} \
		--verbose \
	)

jaxlib_clean:
	rm -f %(kernels_jax_path)s/dist/*

jaxlib_install:
	pip install --force-reinstall %(kernels_jax_path)s/dist/*

refresh_jaxlib: jaxlib_clean jaxlib jaxlib_install
"""
def find_clang(search_root: str = "/usr/lib"):
    """Find a local clang compiler and return its file path.

    Lookup order:

    1. ``clang`` on ``$PATH`` (via :func:`shutil.which`).
    2. An executable file named ``clang`` found by walking *search_root*,
       descending only into its direct ``llvm*`` subdirectories (distro
       packages install e.g. ``/usr/lib/llvm-18/bin/clang``).

    Args:
        search_root: Directory searched when clang is not on ``$PATH``.
            Defaults to ``/usr/lib``.

    Returns:
        The path of the first clang found, or ``None`` if there is none.
    """
    import shutil  # local import: only this function needs it

    # shutil.which does not depend on an external `which` binary; invoking
    # `which` via subprocess raised an uncaught FileNotFoundError when the
    # binary itself was missing.
    clang_path = shutil.which("clang")
    if clang_path:
        return clang_path

    for root, dirs, files in os.walk(search_root):
        if root == search_root:
            # Prune in place with slice assignment so os.walk only descends
            # into llvm* directories. Calling dirs.remove() while iterating
            # over dirs skips the element after each removal, letting
            # unrelated directories slip through.
            dirs[:] = [d for d in dirs if d.startswith("llvm")]
        if "clang" in files:
            candidate = os.path.join(root, "clang")
            # Ignore files that merely share the name but cannot be run.
            if os.access(candidate, os.X_OK):
                return candidate
    # We didn't find a clang install
    return None
def _resolve_relative_paths(xla_dir: str, kernels_jax_dir: str) -> tuple[str, str, str]:
"""Transforms relative to absolute paths. This is needed to properly support
symbolic information remapping"""
this_repo_root = os.path.dirname(os.path.realpath(__file__))
xla_path = (
xla_dir
if os.path.isabs(xla_dir)
else os.path.abspath(f"{this_repo_root}/jax_rocm_plugin/{xla_dir}")
)
assert os.path.isdir(
xla_path
), f"XLA path (specified as '{xla_dir}') doesn't resolve to existing directory at '{xla_path}'"
if kernels_jax_dir:
kernels_jax_path = (
kernels_jax_dir
if os.path.isabs(kernels_jax_dir)
else os.path.abspath(f"{this_repo_root}/jax_rocm_plugin/{kernels_jax_dir}")
)
# pylint: disable=line-too-long
assert os.path.isdir(
kernels_jax_path
), f"XLA path (specified as '{kernels_jax_dir}') doesn't resolve to existing directory at '{kernels_jax_path}'"
else:
kernels_jax_path = None
return this_repo_root, xla_path, kernels_jax_path
def _add_externals_symlink(this_repo_root: str, xla_path: str, kernels_jax_path: str):
"""Adds ./external symlink to $(bazel info output_base)/external into each path"""
assert os.path.isabs(this_repo_root) and os.path.isabs(xla_path)
assert not kernels_jax_path or os.path.isabs(kernels_jax_path)
# checking 'bazel' is executable. We only support essentially bazelisk here.
# Supporting individual bazel binaries installed by the upstream build system
# when it can't find bazel is a TODO for the future.
# Broad exceptions aren't a problem here
# pylint: disable=broad-exception-caught
try:
v = (
subprocess.run(
["bazel", "--version"],
cwd=f"{this_repo_root}/jax_rocm_plugin",
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
)
.stdout.decode("utf-8")
.rstrip()
)
print(
f"Bazelisk is detected (bazel=={v}), proceeding with creation of symlinks"
)
except Exception as e:
print(
"WARNING: Bazelisk is NOT detected and a wrapper for specific bazel "
"versions isn't implemented. Symlinks to '$(bazel info output_base)/external' "
"will not be created in each bazel workspace root, you'll have to make them manually.\n"
f"The error was: {e}"
)
return
def _link(target: str, name: str):
if os.path.exists(name):
print(f"Filesystem object {name} exists, skipping symlink creation.")
else:
os.symlink(target, name, target_is_directory=True)
print(f"Created symlink '{name}'-->'{target}'")
def _make_external(wrkspace: str):
try:
output_base = (
subprocess.run(
["bazel", "info", "output_base"],
cwd=wrkspace,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
)
.stdout.decode("utf-8")
.rstrip()
)
except Exception as e:
print(f"Failed to query 'bazel info output_base' for '{wrkspace}':{e}")
return
_link(f"{output_base}/external", f"{wrkspace}/external")
_make_external(f"{this_repo_root}/jax_rocm_plugin")
_make_external(xla_path) # not necessary, but useful for work on XLA only
if kernels_jax_path:
_make_external(kernels_jax_path)
def _plugin_namespace_version(rocm_path: str) -> str:
    """Return the plugin namespace version for the ROCm install at *rocm_path*.

    The version is read from the first line of ``<rocm_path>/.info/version``.
    ROCm 6 maps to ``"60"``; any other major version maps to its major
    number (``"7"`` for ROCm 7). A major other than 6 or 7 is used as-is,
    with a warning.

    Raises:
        OSError: if the version file cannot be opened (e.g. invalid *rocm_path*).
        ValueError: if no major version can be read from the file.
    """
    version_file = os.path.join(rocm_path, ".info", "version")
    with open(version_file, encoding="utf-8") as versionfile:
        full_version = versionfile.readline().strip()
    # Use the whole major component (text before the first '.'), not just the
    # first character, so a two-digit major such as 10 is not read as "1".
    major = full_version.split(".", 1)[0]
    if not major:
        raise ValueError(f"Cannot determine ROCm version from '{version_file}'")
    if major == "6":
        # note the inconsistency in numbering - ROCm 6 is "60" but ROCm 7 is "7"
        return "60"
    if major != "7":
        # assume that other versions are numbered like 7
        print(f"Warning: using unexpected ROCm version {major}")
    return major


# pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-locals
def setup_development(
    xla_ref: str,
    xla_dir: str,
    test_jax_ref: str,
    kernels_jax_dir: str,
    rebuild_makefile: bool = False,
    fix_bazel_symbols: bool = False,
    rocm_path: str = "/opt/rocm",
):
    """Clone jax and xla repos, and set up the Makefile for developers.

    Paths below are relative to the current working directory:

    * ``./jax`` is cloned from ROCm JAX at *test_jax_ref* if missing.
    * ``./xla`` is cloned from ROCm XLA at *xla_ref* if missing and
      *xla_dir* is the default.
    * ``./jax_rocm_plugin/Makefile`` is (re)generated from MAKE_TEMPLATE
      when *rebuild_makefile* or *fix_bazel_symbols* is set, or when it
      does not exist yet.

    Args:
        xla_ref: Git ref checked out when cloning XLA.
        xla_dir: XLA directory used by the Makefile (absolute or relative
            to jax_rocm_plugin).
        test_jax_ref: Git ref checked out when cloning the test JAX repo.
        kernels_jax_dir: Local JAX used to build the plugin kernels; falsy
            to not override it.
        rebuild_makefile: Regenerate the Makefile even if it exists.
        fix_bazel_symbols: Enable release-with-symbols build options and
            create bazel 'external' symlinks (implies regeneration).
        rocm_path: ROCm installation used for the build.

    Raises:
        subprocess.CalledProcessError: if a git clone fails.
        OSError: if ``<rocm_path>/.info/version`` cannot be read.
    """
    # Always clone the JAX repo that we'll use for running unit tests
    if not os.path.exists("./jax"):
        subprocess.check_call(["git", "clone", "--branch", test_jax_ref, JAX_REPL_URL])

    # clone xla from source for building jax_rocm_plugin if the user didn't
    # specify an existing XLA directory
    if not os.path.exists("./xla") and xla_dir == DEFAULT_XLA_DIR:
        subprocess.check_call(["git", "clone", "--branch", xla_ref, XLA_REPL_URL])

    # create build/install/test script
    makefile_path = "./jax_rocm_plugin/Makefile"
    if not (rebuild_makefile or fix_bazel_symbols or not os.path.exists(makefile_path)):
        return

    this_repo_root, xla_path, kernels_jax_path = _resolve_relative_paths(
        xla_dir, kernels_jax_dir
    )
    if fix_bazel_symbols:
        plugin_bazel_options = "${PLUGIN_SYMBOLS}"
        jaxlib_bazel_options = "${JAXLIB_SYMBOLS}"
        custom_options = " ${CFG_RELEASE_WITH_SYM}"
        _add_externals_symlink(this_repo_root, xla_path, kernels_jax_path)
    else:  # not modifying the build unless asked
        plugin_bazel_options, jaxlib_bazel_options, custom_options = "", "", ""

    # Detect the namespace version from the ROCm version; this raises if the
    # specified ROCm path is invalid (e.g. there is no .info/version).
    plugin_namespace_version = _plugin_namespace_version(rocm_path)

    kvs = {
        "clang_path": "/usr/lib/llvm-18/bin/clang",
        "plugin_version": plugin_namespace_version,
        "this_repo_root": this_repo_root,
        "xla_path": xla_path,
        "kernels_jax_path": kernels_jax_path,
        "plugin_bazel_options": plugin_bazel_options,
        "jaxlib_bazel_options": jaxlib_bazel_options,
        "custom_options": custom_options,
        # If the user wants to use their own JAX for building the plugin wheel
        # that contains all the jaxlib kernel code (jax_rocm7_plugin), add that
        # to the Makefile.
        "kernels_jax_override": (
            f"--override_repository=jax={kernels_jax_path}" if kernels_jax_path else ""
        ),
        "rocm_path": rocm_path,
    }
    clang_path = find_clang()
    if clang_path:
        print("Found clang at %r" % clang_path)
        kvs["clang_path"] = clang_path
    else:
        print("No clang found. Defaulting to %r" % kvs["clang_path"])

    makefile_content = MAKE_TEMPLATE % kvs
    with open(makefile_path, "w", encoding="utf-8") as mf:
        mf.write(makefile_content)
def dev_docker(rm):
    """Run an interactive ``ubuntu:24.04`` container for local plugin development.

    The current directory is bind-mounted at ``/rocm-jax`` (also exported as
    ``ROCM_JAX_DIR``), ``/dev/kfd`` and ``/dev/dri`` are passed through as
    devices, and the call blocks until the container process exits.

    Args:
        rm: When truthy, pass ``--rm`` so docker removes the container on exit.
    """
    mount_spec = "%s:/rocm-jax" % os.path.abspath(os.curdir)
    device_flags = ["--device=/dev/kfd", "--device=/dev/dri"]
    runtime_flags = [
        "--ipc=host",
        "--shm-size=16G",
        "--group-add",
        "video",
        "--cap-add=SYS_PTRACE",
        "--security-opt",
        "seccomp=unconfined",
    ]
    volume_flags = ["-v", mount_spec, "--env=ROCM_JAX_DIR=/rocm-jax"]
    cleanup_flags = ["--rm"] if rm else []

    docker_cmd = (
        ["docker", "run", "-it", "--network=host"]
        + device_flags
        + runtime_flags
        + volume_flags
        + cleanup_flags
        + ["ubuntu:24.04"]
    )
    with subprocess.Popen(docker_cmd) as proc:
        proc.wait()
# Planned "build mode" setup (not implemented yet):
#   - install jax/jaxlib from known versions
#   - set up the build/install/test script
def setup_build():
    """Set up for building the plugin locally.

    Placeholder for the planned build mode.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Argument strings to parse, excluding the program name. ``None``
            (the default) parses ``sys.argv[1:]``, as
            :meth:`argparse.ArgumentParser.parse_args` does.

    Returns:
        An ``argparse.Namespace`` whose ``action`` is ``"develop"`` or
        ``"docker"``, plus the options of that sub-command.
    """
    p = argparse.ArgumentParser()
    subp = p.add_subparsers(dest="action", required=True)

    dev = subp.add_parser("develop")
    dev.add_argument(
        "--rebuild-makefile",
        help="Force rebuild of Makefile from template.",
        action="store_true",
    )
    dev.add_argument(
        "--xla-ref",
        help="XLA commit reference to checkout on clone",
        default=XLA_REPO_REF,
    )
    dev.add_argument(
        "--xla-dir",
        help=(
            "Set the XLA path in the Makefile. This must either be a path "
            "relative to jax_rocm_plugin or an absolute path."
        ),
        default=DEFAULT_XLA_DIR,
    )
    dev.add_argument(
        "--jax-ref",
        help="JAX commit reference to checkout on clone",
        default=TEST_JAX_REPO_REF,
    )
    dev.add_argument(
        "--kernel-jax-dir",
        help=(
            "If you want to use a local JAX directory for building the "
            "plugin kernels wheel (jax_rocm7_plugin), the path to the "
            "directory of repo. Defaults to %s" % DEFAULT_KERNELS_JAX_DIR
        ),
        default=DEFAULT_KERNELS_JAX_DIR,
    )
    dev.add_argument(
        "--fix-bazel-symbols",
        help="When this option is enabled, the script assumes you need to build "
        "code in a release with symbolic info configuration to alleviate debugging. "
        "The script enables respective bazel options and adds 'external' symbolic "
        "links to corresponding workspaces pointing to bazel's dependencies storage.",
        action="store_true",
    )
    dev.add_argument(
        "--rocm-path",
        help="Location of the ROCm to use for building Jax",
        default="/opt/rocm",
    )

    doc_parser = subp.add_parser("docker")
    doc_parser.add_argument(
        "--rm",
        help="Remove the dev docker container after it exits",
        action="store_true",
    )

    # Passing argv through (None -> sys.argv[1:]) keeps the CLI behaviour
    # unchanged while allowing programmatic use and testing.
    return p.parse_args(argv)
def main():
    """Entry point: dispatch to the sub-command chosen on the command line."""
    args = parse_args()
    if args.action == "develop":
        setup_development(
            xla_ref=args.xla_ref,
            xla_dir=args.xla_dir,
            test_jax_ref=args.jax_ref,
            kernels_jax_dir=args.kernel_jax_dir,
            rebuild_makefile=args.rebuild_makefile,
            fix_bazel_symbols=args.fix_bazel_symbols,
            rocm_path=args.rocm_path,
        )
    elif args.action == "docker":
        dev_docker(rm=args.rm)


if __name__ == "__main__":
    main()