forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_torch_version.py
More file actions
162 lines (145 loc) · 6.14 KB
/
generate_torch_version.py
File metadata and controls
162 lines (145 loc) · 6.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from __future__ import annotations
import argparse
import email
import os
import re
import subprocess
from pathlib import Path
from packaging.version import Version
from setuptools import distutils # type: ignore[import,attr-defined]
# Sentinel string returned by get_sha()/get_tag() when no value could be
# determined, and checked by get_torch_version() and the __main__ block.
UNKNOWN = "Unknown"
# Intended to recognise release tags such as "v2.1.0" or "v2.1.0-rc1".
# NOTE(review): the leading and trailing "/" look like JavaScript-style regex
# delimiters. Because they are part of the pattern, RELEASE_PATTERN.match()
# only succeeds on strings that begin with "/". git ref names cannot begin
# with a slash, so the output of `git describe` never matches and get_tag()
# always returns UNKNOWN. Dropping the slashes would make the __main__ block
# use the raw tag (with its "v" prefix) as __version__ instead of
# get_torch_version() — confirm that is wanted before changing it.
RELEASE_PATTERN = re.compile(r"/v[0-9]+(\.[0-9]+)*(-rc[0-9]+)?/")
def get_sha(pytorch_root: str | Path) -> str:
    """Return the commit hash of the checkout rooted at ``pytorch_root``.

    Git is queried when ``pytorch_root/.git`` exists; otherwise Mercurial is
    queried when ``pytorch_root/.hg`` exists. Both commands print only the
    full node hash of the checked-out revision.

    This is deliberately best-effort. If neither VCS directory exists, the
    tool is not installed, the command fails, or its output is not ASCII,
    ``UNKNOWN`` is returned instead of raising.

    Args:
        pytorch_root: Directory expected to contain the ``.git`` / ``.hg``
            metadata; the VCS command runs with this as its working directory.

    Returns:
        The whitespace-stripped hash, or ``UNKNOWN`` if it could not be
        determined (including when the command printed only whitespace).
    """
    try:
        if os.path.exists(os.path.join(pytorch_root, ".git")):
            cmd = ["git", "rev-parse", "HEAD"]
        elif os.path.exists(os.path.join(pytorch_root, ".hg")):
            # `hg identify -r .` appends branch/tag/bookmark names after the
            # hash (e.g. "1a2b3c4d5e6f tip"), which would end up in
            # git_version. A template prints just the node hash, matching the
            # shape of `git rev-parse HEAD` output.
            cmd = ["hg", "log", "-r", ".", "--template", "{node}"]
        else:
            return UNKNOWN
        sha = subprocess.check_output(cmd, cwd=pytorch_root).decode("ascii").strip()
    except Exception:
        # Any failure (tool missing, not a repository, undecodable output)
        # degrades to UNKNOWN rather than aborting the caller.
        return UNKNOWN
    return sha or UNKNOWN
def get_tag(pytorch_root: str | Path) -> str:
    """Return the release tag that points exactly at HEAD, or ``UNKNOWN``.

    Runs ``git describe --tags --exact`` (git accepts ``--exact`` as an
    abbreviation of ``--exact-match``) inside ``pytorch_root``. The output is
    accepted only when ``RELEASE_PATTERN`` matches it. Failing to run git at
    all is treated the same as "no matching tag".

    NOTE(review): RELEASE_PATTERN begins with "/", which git tag names never
    do, so with the pattern as written this returns UNKNOWN — confirm.
    """
    try:
        completed = subprocess.run(
            ["git", "describe", "--tags", "--exact"],
            cwd=pytorch_root,
            encoding="ascii",
            capture_output=True,
        )
        candidate = completed.stdout.strip()
        is_release = RELEASE_PATTERN.match(candidate) is not None
    except Exception:
        return UNKNOWN
    return candidate if is_release else UNKNOWN
def get_torch_version(sha: str | None = None) -> str:
    """Determine the torch version string.

    The version is determined from one of the following sources, in order of
    precedence:

    1. The PYTORCH_BUILD_VERSION and PYTORCH_BUILD_NUMBER environment variables.
       These are set by the PyTorch build system when building official
       releases. A build number greater than 1 is appended as ``.post<N>``.
       If built from an sdist, it is checked that the version matches the
       sdist version.
    2. The PKG-INFO file, if it exists. This file is included in source
       distributions (sdist) and contains the version of the sdist.
    3. The version.txt file, which contains the base version string. If the git
       commit SHA is available, its first 7 characters are appended as a
       ``+git<sha>`` local version to indicate a development build.

    Args:
        sha: Commit hash to embed in development versions. ``None`` means
            "look it up with get_sha()". ``UNKNOWN`` means "no hash available",
            and no ``+git`` suffix is added. A lookup that yields ``UNKNOWN``
            is treated the same way.

    Returns:
        The version string, after checking that it parses as a PEP 440 version.

    Raises:
        AssertionError: If PYTORCH_BUILD_VERSION is set without
            PYTORCH_BUILD_NUMBER, or if the version does not match the sdist
            version recorded in PKG-INFO.
        packaging.version.InvalidVersion: If the version is not PEP 440
            compliant.
    """
    pytorch_root = Path(__file__).absolute().parent.parent
    pkg_info_path = pytorch_root / "PKG-INFO"
    if pkg_info_path.exists():
        # PKG-INFO is an RFC 822-style header file; read it as UTF-8 rather
        # than relying on the locale's default encoding.
        with open(pkg_info_path, encoding="utf-8") as f:
            pkg_info = email.message_from_file(f)
        sdist_version = pkg_info["Version"]
    else:
        sdist_version = None

    if os.getenv("PYTORCH_BUILD_VERSION"):
        if os.getenv("PYTORCH_BUILD_NUMBER") is None:
            raise AssertionError(
                "PYTORCH_BUILD_NUMBER must be set when PYTORCH_BUILD_VERSION is set"
            )
        build_number = int(os.getenv("PYTORCH_BUILD_NUMBER", ""))
        version = os.getenv("PYTORCH_BUILD_VERSION", "")
        if build_number > 1:
            version += ".post" + str(build_number)
        origin = "PYTORCH_BUILD_{VERSION,NUMBER} env variables"
    elif sdist_version:
        version = sdist_version
        origin = "PKG-INFO"
    else:
        version = (pytorch_root / "version.txt").read_text(encoding="utf-8").strip()
        origin = "version.txt"
        if sdist_version is None:
            if sha is None:
                sha = get_sha(pytorch_root)
            # Check for UNKNOWN only after the lazy lookup. get_sha() returns
            # UNKNOWN when no VCS metadata is available, and checking earlier
            # would produce a bogus "+gitUnknown" local version.
            if sha != UNKNOWN:
                version += "+git" + sha[:7]
                origin += " and git commit"

    # Validate that the version is PEP 440 compliant
    parsed_version = Version(version)
    if sdist_version:
        local = parsed_version.local
        if local and local.startswith("git"):
            # Assume local version is git<sha> and
            # hence whole version is source version
            source_version = version
        else:
            # local version is absent or platform tag
            source_version = version.partition("+")[0]
        if sdist_version != source_version:
            raise AssertionError(
                f"Source part '{source_version}' of version '{version}' from "
                f"{origin} does not match version '{sdist_version}' from PKG-INFO"
            )
    return version
if __name__ == "__main__":
    # Command-line entry point: write torch/version.py from the parsed build
    # flags plus VCS / environment version metadata.
    parser = argparse.ArgumentParser(
        description="Generate torch/version.py from build and environment metadata."
    )
    parser.add_argument(
        "--is-debug",
        "--is_debug",
        type=distutils.util.strtobool,
        help="Whether this build is debug mode or not.",
    )
    parser.add_argument("--cuda-version", "--cuda_version", type=str)
    parser.add_argument("--hip-version", "--hip_version", type=str)
    parser.add_argument("--rocm-version", "--rocm_version", type=str)
    parser.add_argument("--xpu-version", "--xpu_version", type=str)
    args = parser.parse_args()

    if args.is_debug is None:
        raise AssertionError("is_debug argument must be provided")
    # An empty string (presumably what the build system passes when a backend
    # is absent — confirm) is recorded as None in the generated file.
    for attr in ("cuda_version", "hip_version", "rocm_version", "xpu_version"):
        if getattr(args, attr) == "":
            setattr(args, attr, None)

    # Resolve to an absolute path, as get_torch_version() does, so the output
    # location does not depend on how the script path was spelled.
    pytorch_root = Path(__file__).absolute().parent.parent
    version_path = pytorch_root / "torch" / "version.py"

    # Attempt to get tag first, fall back to sha if a tag was not found
    tagged_version = get_tag(pytorch_root)
    sha = get_sha(pytorch_root)
    if tagged_version == UNKNOWN:
        version = get_torch_version(sha)
    else:
        version = tagged_version

    with open(version_path, "w", encoding="utf-8") as f:
        f.write("from typing import Optional\n\n")
        f.write(
            "__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip', 'rocm', 'xpu']\n"
        )
        # repr() quotes and escapes like every other field below; for ordinary
        # version strings it yields exactly '<version>'.
        f.write(f"__version__ = {repr(version)}\n")
        # NB: This is not 100% accurate, because you could have built the
        # library code with DEBUG, but csrc without DEBUG (in which case
        # this would claim to be a release build when it's not.)
        f.write(f"debug = {repr(bool(args.is_debug))}\n")
        f.write(f"cuda: Optional[str] = {repr(args.cuda_version)}\n")
        f.write(f"git_version = {repr(sha)}\n")
        f.write(f"hip: Optional[str] = {repr(args.hip_version)}\n")
        f.write(f"rocm: Optional[str] = {repr(args.rocm_version)}\n")
        f.write(f"xpu: Optional[str] = {repr(args.xpu_version)}\n")