Skip to content

Commit 2fbee44

Browse files
committed
Add augment to version string and allow install path overriding with NVTE_INSTALL_PATH
1 parent ec678af commit 2fbee44

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

setup.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ def te_version() -> str:
2727
Includes Git commit as local version, unless suppressed with
2828
NVTE_NO_LOCAL_VERSION environment variable.
2929
30+
[augment] We replace the git hash with the string "augment"
31+
to identify this as a custom augment release.
32+
3033
"""
3134
with open(root_path / "VERSION", "r") as f:
3235
version = f.readline().strip()
@@ -43,7 +46,7 @@ def te_version() -> str:
4346
pass
4447
else:
4548
commit = output.stdout.strip()
46-
version += f"+{commit}"
49+
version += "+augment"
4750
return version
4851

4952
@lru_cache(maxsize=1)
@@ -290,6 +293,7 @@ def add_unique(l: List[str], vals: Union[str, List[str]]) -> None:
290293

291294
# Framework-specific requirements
292295
if "pytorch" in frameworks():
296+
# [augment] We remove the versioning requirement on flash-attn
293297
add_unique(install_reqs, ["torch", "flash-attn"])
294298
add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
295299
if "jax" in frameworks():

transformer_engine/common/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
def get_te_path():
1414
"""Find Transformer Engine install path using pip"""
1515

16+
# [augment] Allow overriding install path with an envvar. Should be something like:
17+
# > NVTE_INSTALL_PATH=/opt/conda/lib/python3.9/site-packages
18+
if "NVTE_INSTALL_PATH" in os.environ:
19+
return os.environ["NVTE_INSTALL_PATH"]
20+
1621
command = [sys.executable, "-m", "pip", "show", "transformer_engine"]
1722
result = subprocess.run(command, capture_output=True, check=True, text=True)
1823
result = result.stdout.replace("\n", ":").split(":")

0 commit comments

Comments
 (0)