File tree Expand file tree Collapse file tree 1 file changed +30
-0
lines changed
Expand file tree Collapse file tree 1 file changed +30
-0
lines changed Original file line number Diff line number Diff line change 1+ #! /bin/bash
2+ set -euxo pipefail
3+
4+ # Get versions from environment variables, with defaults
5+ cuda_version=" ${CUDA_VERSION:- 12.1} "
6+ pytorch_version=" ${PYTORCH_VERSION:- latest} "
7+
8+ # Determine CUDA short version for wheel index
9+ case " $cuda_version " in
10+ " 11.8" ) cuda_short=" cu118" ;;
11+ " 12.1" ) cuda_short=" cu121" ;;
12+ " 12.4" ) cuda_short=" cu124" ;;
13+ " 12.6" ) cuda_short=" cu126" ;;
14+ " 12.8" ) cuda_short=" cu128" ;;
15+ * )
16+ echo " Error: Unsupported CUDA version: $cuda_version "
17+ exit 1
18+ ;;
19+ esac
20+
21+ index_url=" https://download.pytorch.org/whl/$cuda_short "
22+ echo " PyTorch wheel index: $index_url "
23+
24+ if [ " $pytorch_version " = " latest" ]; then
25+ echo " Installing latest PyTorch for CUDA $cuda_version "
26+ pip install torch torchvision --index-url " $index_url "
27+ else
28+ echo " Installing PyTorch $pytorch_version for CUDA $cuda_version "
29+ pip install torch==" $pytorch_version " torchvision --index-url " $index_url "
30+ fi
You can’t perform that action at this time.
0 commit comments