Skip to content

Commit 80a2d68

Browse files
committed
Update versioning to automatically include cuda and torch version
1 parent f47af0d commit 80a2d68

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

setup.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import ctypes
66
from functools import lru_cache
77
import os
8+
from packaging.version import parse
89
from pathlib import Path
910
import re
1011
import shutil
@@ -33,6 +34,17 @@ def te_version() -> str:
3334
"""
3435
with open(root_path / "VERSION", "r") as f:
3536
version = f.readline().strip()
37+
38+
# [augment] Here is where we replace the git hash with our own versioning.
39+
# You can disable this behavior with NVTE_NO_AUGMENT_VERSION=1.
40+
if not int(os.getenv("NVTE_NO_AUGMENT_VERSION", "0")):
41+
# NOTE: we are assuming you are building for pytorch. TE cannot make this assumption in general.
42+
import torch
43+
torch_version = parse(torch.__version__)
44+
cuda_version = parse(torch.version.cuda)
45+
version_string = f".cu{cuda_version.major}{cuda_version.minor}.torch{torch_version.major}{torch_version.minor}"
46+
return version + "+augment" + version_string
47+
3648
if not int(os.getenv("NVTE_NO_LOCAL_VERSION", "0")):
3749
try:
3850
output = subprocess.run(
@@ -46,7 +58,7 @@ def te_version() -> str:
4658
pass
4759
else:
4860
commit = output.stdout.strip()
49-
version += "+augment"
61+
version += f"+{commit}"
5062
return version
5163

5264
@lru_cache(maxsize=1)

0 commit comments

Comments
 (0)