Skip to content

Commit daae258

Browse files
authored
Create install_pytorch.sh
1 parent ecad2ac commit daae258

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

.github/tools/install_pytorch.sh

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#!/bin/bash
2+
set -euxo pipefail
3+
4+
# Get versions from environment variables, with defaults
5+
cuda_version="${CUDA_VERSION:-12.1}"
6+
pytorch_version="${PYTORCH_VERSION:-latest}"
7+
8+
# Determine CUDA short version for wheel index
9+
case "$cuda_version" in
10+
"11.8") cuda_short="cu118";;
11+
"12.1") cuda_short="cu121";;
12+
"12.4") cuda_short="cu124";;
13+
"12.6") cuda_short="cu126";;
14+
"12.8") cuda_short="cu128";;
15+
*)
16+
echo "Error: Unsupported CUDA version: $cuda_version"
17+
exit 1
18+
;;
19+
esac
20+
21+
index_url="https://download.pytorch.org/whl/$cuda_short"
22+
echo "PyTorch wheel index: $index_url"
23+
24+
if [ "$pytorch_version" = "latest" ]; then
25+
echo "Installing latest PyTorch for CUDA $cuda_version"
26+
pip install torch torchvision --index-url "$index_url"
27+
else
28+
echo "Installing PyTorch $pytorch_version for CUDA $cuda_version"
29+
pip install torch=="$pytorch_version" torchvision --index-url "$index_url"
30+
fi

0 commit comments

Comments
 (0)