-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathsetup.py
More file actions
103 lines (84 loc) · 3.08 KB
/
setup.py
File metadata and controls
103 lines (84 loc) · 3.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import glob
import os
import os.path as osp
import pathlib
import platform
import sys
from setuptools import setup
# Build switches taken from the environment.  Each one is true only when the
# variable is exactly "1".  BUILD_NO_CUDA defaults to "1", so the compiled
# extension is skipped unless the caller sets BUILD_NO_CUDA to something else
# (e.g. BUILD_NO_CUDA=0).
BUILD_NO_CUDA = os.environ.get("BUILD_NO_CUDA", "1") == "1"
WITH_SYMBOLS = os.environ.get("WITH_SYMBOLS", "0") == "1"
LINE_INFO = os.environ.get("LINE_INFO", "0") == "1"

# Default MAX_JOBS to "10" when the caller left it unset or empty.  The flag
# remembers that this script applied the default so it can be removed again
# after setup() runs.  NOTE(review): MAX_JOBS is presumably read by torch's
# ninja-based extension build to limit parallel compile jobs -- confirm.
MAX_JOBS = os.environ.get("MAX_JOBS")
need_to_unset_max_jobs = not MAX_JOBS
if need_to_unset_max_jobs:
    os.environ["MAX_JOBS"] = "10"
def get_ext():
    """Return the ``build_ext`` command class for the compiled extension.

    The result is PyTorch's ``BuildExtension`` specialised with
    ``no_python_abi_suffix=True`` and ``use_ninja=True``.  torch is imported
    inside the function so it is only needed when this is called, which
    ``setup()`` does only when the extension is being built.
    """
    from torch.utils.cpp_extension import BuildExtension as _BuildExtension

    build_options = {"no_python_abi_suffix": True, "use_ninja": True}
    return _BuildExtension.with_options(**build_options)
def get_extensions():
    """Describe the compiled ``gsplat.csrc`` extension for ``setup()``.

    Sources are every ``*.cu`` and ``*.cpp`` file in ``gsplat/cuda/csrc``
    plus ``gsplat/cuda/ext.cpp``.  Compiler/linker flags depend on the host
    platform, on how the installed PyTorch was built (OpenMP backend,
    CUDA vs. ROCm/HIP) and on:

    * ``NVCC_FLAGS`` (env) -- extra whitespace-separated device-compiler
      flags, placed before the built-in ones.
    * ``WITH_SYMBOLS`` (module flag) -- when false the library is linked
      with ``-s``.
    * ``LINE_INFO`` (module flag) -- when true ``-lineinfo`` is added to the
      device-compiler flags.

    torch is imported lazily so that it is only required when the extension
    is actually being built.

    Returns:
        list: a single ``CUDAExtension`` named ``gsplat.csrc``.
    """
    import torch
    from torch.__config__ import parallel_info
    from torch.utils.cpp_extension import CUDAExtension

    # --- sources ---------------------------------------------------------
    # Each glob is sorted so the build order does not depend on the
    # filesystem's directory-listing order (reproducible builds).
    extensions_dir = osp.join("gsplat", "cuda")
    sources = sorted(glob.glob(osp.join(extensions_dir, "csrc", "*.cu")))
    sources += sorted(glob.glob(osp.join(extensions_dir, "csrc", "*.cpp")))
    sources += [osp.join(extensions_dir, "ext.cpp")]

    undef_macros = []
    define_macros = []

    # --- host (C++) compiler / linker flags --------------------------------
    # NOTE(review): "-O3" is the GCC/Clang spelling; MSVC uses "/O2" and may
    # ignore "-O3" -- confirm the optimisation level Windows builds receive.
    extra_compile_args = {"cxx": ["-O3"]}
    if os.name != "nt":
        extra_compile_args["cxx"] += ["-Wno-sign-compare"]
    # Strip symbols from the shared library unless WITH_SYMBOLS was requested.
    extra_link_args = [] if WITH_SYMBOLS else ["-s"]

    # Enable OpenMP only when this PyTorch build reports an OpenMP backend,
    # and never on macOS.
    info = parallel_info()
    if (
        "backend: OpenMP" in info
        and "OpenMP not found" not in info
        and sys.platform != "darwin"
    ):
        extra_compile_args["cxx"] += ["-DAT_PARALLEL_OPENMP"]
        if sys.platform == "win32":
            extra_compile_args["cxx"] += ["/openmp"]
        else:
            extra_compile_args["cxx"] += ["-fopenmp"]

    # Target Apple Silicon explicitly when building on an arm64 Mac.
    if sys.platform == "darwin" and platform.machine() == "arm64":
        extra_compile_args["cxx"] += ["-arch", "arm64"]
        extra_link_args += ["-arch", "arm64"]

    # --- device (nvcc / hipcc) flags --------------------------------------
    # str.split() with no separator splits on runs of whitespace and drops
    # empty fields: an unset/empty NVCC_FLAGS gives [], and stray or doubled
    # spaces no longer turn into empty-string compiler arguments.
    nvcc_flags = os.getenv("NVCC_FLAGS", "").split()
    nvcc_flags += ["-O3", "--use_fast_math", "-std=c++17"]
    if LINE_INFO:
        nvcc_flags += ["-lineinfo"]
    if torch.version.hip:
        # ROCm/HIP build of PyTorch.
        define_macros += [("USE_ROCM", None)]
        undef_macros += ["__HIP_NO_HALF_CONVERSIONS__"]
    else:
        nvcc_flags += ["--expt-relaxed-constexpr"]
        # Silence nvcc diagnostics 20012 and 186.
        nvcc_flags += ["-diag-suppress", "20012,186"]
    if sys.platform == "win32":
        nvcc_flags += [
            "-DWIN32_LEAN_AND_MEAN",
            "-allow-unsupported-compiler",
        ]
    extra_compile_args["nvcc"] = nvcc_flags

    # --- include paths (absolute, anchored at this setup.py) ---------------
    current_dir = pathlib.Path(__file__).parent.resolve()
    glm_path = osp.join(current_dir, "gsplat", "cuda", "csrc", "third_party", "glm")
    include_dirs = [glm_path, osp.join(current_dir, "gsplat", "cuda", "include")]

    return [
        CUDAExtension(
            "gsplat.csrc",
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            undef_macros=undef_macros,
            extra_compile_args=extra_compile_args,
            extra_link_args=extra_link_args,
        )
    ]
# Register the compiled extension and its build command only when
# BUILD_NO_CUDA is false; otherwise install without any extension modules.
try:
    setup(
        ext_modules=[] if BUILD_NO_CUDA else get_extensions(),
        cmdclass={} if BUILD_NO_CUDA else {"build_ext": get_ext()},
    )
finally:
    # Remove the MAX_JOBS default this script applied, even when setup()
    # raises or exits early (e.g. when a build backend runs this file
    # in-process).  The pop() default tolerates the variable already being
    # gone.
    if need_to_unset_max_jobs:
        os.environ.pop("MAX_JOBS", None)