diff --git a/src/aedifix/packages/cuda.py b/src/aedifix/packages/cuda.py index 687e33e..adf215b 100644 --- a/src/aedifix/packages/cuda.py +++ b/src/aedifix/packages/cuda.py @@ -27,6 +27,20 @@ def _guess_cuda_compiler() -> str | None: return None +def _guess_cuda_architecture() -> str: + try: + return os.environ["CUDAARCHS"] + except KeyError: + pass + + try: + return os.environ["CMAKE_CUDA_ARCHITECTURES"] + except KeyError: + pass + + return "all-major" + + class CudaArchAction(Action): @staticmethod def map_cuda_arch_names(in_arch: str) -> list[str]: @@ -117,7 +131,7 @@ class CUDA(Package): spec=ArgSpec( dest="cuda_arch", required=False, - default=["all-major"], + default=_guess_cuda_architecture(), action=CudaArchAction, help=( "Specify the target GPU architecture. Available choices are: " diff --git a/tests/packages/test_cuda.py b/tests/packages/test_cuda.py index c247601..afc4b39 100644 --- a/tests/packages/test_cuda.py +++ b/tests/packages/test_cuda.py @@ -7,7 +7,7 @@ import pytest -from aedifix.packages.cuda import CudaArchAction +from aedifix.packages.cuda import CudaArchAction, _guess_cuda_architecture ARCH_STR: tuple[tuple[str, list[str]], ...] = ( ("", []), @@ -28,5 +28,27 @@ def test_map_cuda_arch_names(self, argv: str, expected: list[str]) -> None: assert ret == expected +class TestCUDA: + @pytest.mark.parametrize( + ("env_var", "env_value", "expected"), + [ + ("CUDAARCHS", "volta", "volta"), + ("CMAKE_CUDA_ARCHITECTURES", "75", "75"), + ("", "", "all-major"), + ], + ) + def test_default_cuda_arches( + self, + monkeypatch: pytest.MonkeyPatch, + env_var: str, + env_value: str, + expected: str, + ) -> None: + if env_var: + monkeypatch.setenv(env_var, env_value) + + assert _guess_cuda_architecture() == expected + + if __name__ == "__main__": sys.exit(pytest.main())