diff --git a/setup.py b/setup.py index 2e59f85..a8e3ff2 100644 --- a/setup.py +++ b/setup.py @@ -8,17 +8,24 @@ root_dir = os.path.dirname(os.path.abspath(__file__)) - -mpi_home = os.environ.get("MPI_HOME") -if mpi_home is None: - mpi_home = "/usr/lib/openmpi/" -if not os.path.exists(mpi_home): - mpi_home = "/usr/lib/x86_64-linux-gnu/openmpi/" -if not os.path.exists(mpi_home): - print("Couldn't find MPI install dir, please set MPI_HOME env variable") +# Sensible defaults. +mpi_lib_path = os.environ.get("MPI_LIB_PATH", "/usr/lib/openmpi") +mpi_inc_path = os.environ.get("MPI_INC_PATH", "/usr/include/openmpi") +mpi_home = os.environ.get("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi/") + +# If MPI_HOME is valid, derive inc/lib. +if os.path.exists(mpi_home): + mpi_lib_path = os.path.join(mpi_home, "lib") + mpi_inc_path = os.path.join(mpi_home, "include") + +if not (os.path.exists(mpi_lib_path) and os.path.exists(mpi_inc_path)): + logging.warn(mpi_lib_path) + logging.warn(mpi_inc_path) + print("Couldn't find MPI install dir, please set MPI_HOME env variable or " + "set MPI_LIB_PATH and MPI_INC_PATH separately for include files " + "and library files") sys.exit(1) - nccl_home = os.environ.get("NCCL_HOME") if nccl_home is None or not os.path.exists(nccl_home): nccl_home = None @@ -40,11 +47,11 @@ "src/ProcessGroupMPI.cpp", ], include_dirs=[ - os.path.join(root_dir, "include"), - os.path.join(mpi_home, "include"), + os.path.join(root_dir, "include"), + mpi_inc_path, ], library_dirs=[ - os.path.join(mpi_home, "lib"), + mpi_lib_path, ], libraries=["mpi",], extra_compile_args=["-DOMPI_SKIP_MPICXX=1"] + torch_version_defines, @@ -86,6 +93,7 @@ "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", "License :: OSI Approved :: BSD License", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Operating System :: OS Independent",