Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion tests/test_installer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
import pytest
from unittest.mock import patch
from torchruntime.installer import get_install_commands, get_pip_commands, run_commands
from torchruntime.installer import get_install_commands, get_pip_commands, run_commands, install


def test_empty_args():
Expand Down Expand Up @@ -125,3 +125,57 @@ def test_run_commands():
# Check that subprocess.run was called with the correct arguments
mock_run.assert_any_call(cmds[0])
mock_run.assert_any_call(cmds[1])


def test_install_promotes_cuda_platform_for_torch_27(monkeypatch):
    # install() should bump a detected cu124 platform up to cu128 when
    # the user pins torch 2.7.x (published only under cu128 wheels).
    seen = {}

    def recording_install_commands(torch_platform, packages):
        seen["platform"] = torch_platform
        return [packages]

    stubs = {
        "get_gpus": lambda: [],
        "get_torch_platform": lambda gpu_infos: "cu124",
        "get_install_commands": recording_install_commands,
        "get_pip_commands": lambda cmds, use_uv=False: cmds,
        "run_commands": lambda cmds: None,
    }
    for name, stub in stubs.items():
        monkeypatch.setattr(f"torchruntime.installer.{name}", stub)

    install(["torch==2.7.1"])

    assert seen["platform"] == "cu128"


def test_install_demotes_cuda_platform_for_torch_26(monkeypatch):
    # install() should drop a detected cu128 platform down to cu124 when
    # the user pins torch 2.6.x (published only under cu124 wheels).
    seen = {}

    def recording_install_commands(torch_platform, packages):
        seen["platform"] = torch_platform
        return [packages]

    stubs = {
        "get_gpus": lambda: [],
        "get_torch_platform": lambda gpu_infos: "cu128",
        "get_install_commands": recording_install_commands,
        "get_pip_commands": lambda cmds, use_uv=False: cmds,
        "run_commands": lambda cmds: None,
    }
    for name, stub in stubs.items():
        monkeypatch.setattr(f"torchruntime.installer.{name}", stub)

    install(["torch==2.6.0"])

    assert seen["platform"] == "cu124"


def test_install_promotes_cuda_platform_for_torchvision_022(monkeypatch):
    # torchvision 0.22.x pairs with torch 2.7.x, so pinning it should also
    # promote a detected cu124 platform to cu128.
    seen = {}

    def recording_install_commands(torch_platform, packages):
        seen["platform"] = torch_platform
        return [packages]

    stubs = {
        "get_gpus": lambda: [],
        "get_torch_platform": lambda gpu_infos: "cu124",
        "get_install_commands": recording_install_commands,
        "get_pip_commands": lambda cmds, use_uv=False: cmds,
        "run_commands": lambda cmds: None,
    }
    for name, stub in stubs.items():
        monkeypatch.setattr(f"torchruntime.installer.{name}", stub)

    install(["torchvision==0.22.0"])

    assert seen["platform"] == "cu128"
128 changes: 128 additions & 0 deletions torchruntime/installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@

from .consts import CONTACT_LINK
from .device_db import get_gpus
from .gpu_db import get_nvidia_arch
from .platform_detection import get_torch_platform

os_name = platform.system()  # e.g. "Windows", "Linux", "Darwin"

# Base pip invocation; package specs and index-url flags get appended to it.
PIP_PREFIX = [sys.executable, "-m", "pip", "install"]
# Matches CUDA torch platform ids like "cu124" or "nightly/cu128".
CUDA_REGEX = re.compile(r"^(nightly/)?cu\d+$")
# Matches ROCm torch platform ids like "rocm6.2" or "nightly/rocm6.2".
ROCM_REGEX = re.compile(r"^(nightly/)?rocm\d+\.\d+$")
# Parses a single pip requirement spec into name, optional extras, comparison
# operator, and the first version token (e.g. "torch[extra]==2.7.1").
REQ_SPEC_REGEX = re.compile(
    r"^\s*(?P<name>[A-Za-z0-9_.-]+)(?:\[[^\]]+\])?\s*(?P<op>==|>=|<=|~=|!=|<|>)\s*(?P<version>[^,;\s]+)"
)
# Extracts the leading "major.minor" portion of a version string.
MAJOR_MINOR_REGEX = re.compile(r"^(?P<major>\d+)\.(?P<minor>\d+)")
# First torch release published under cu128 wheels; older releases use cu124.
TORCH_2_7 = (2, 7)


def get_install_commands(torch_platform, packages):
Expand Down Expand Up @@ -91,6 +97,127 @@ def run_commands(cmds):
subprocess.run(cmd)


def _parse_major_minor(version: str):
    """Return the leading ``(major, minor)`` int pair of *version*, or None.

    Anything after the second numeric component (patch, pre/post suffixes)
    is ignored; a string that does not start with "N.N" yields None.
    """
    parsed = MAJOR_MINOR_REGEX.match(version)
    if parsed is None:
        return None
    return int(parsed["major"]), int(parsed["minor"])


def _is_major_minor_gte(left, right):
return left[0] > right[0] or (left[0] == right[0] and left[1] >= right[1])


def _is_major_minor_lt(left, right):
return left[0] < right[0] or (left[0] == right[0] and left[1] < right[1])


def _cuda_platform_has_prefix(torch_platform: str):
return torch_platform.startswith("nightly/")


def _cuda_platform_with_prefix(torch_platform: str, cuda_platform: str):
if _cuda_platform_has_prefix(torch_platform):
return f"nightly/{cuda_platform}"
return cuda_platform


def _get_cuda_platform_for_pytorch_packages(packages):
    """
    Infer a CUDA platform (cu124 vs cu128) from user-specified PyTorch package versions.

    This is needed because PyTorch 2.7.x is published under cu128 wheels, and older
    releases (<=2.6) are published under cu124 wheels. When the requested versions
    are pinned, the installer must select the matching index URL or pip will fail
    with "No matching distribution found".

    Args:
        packages: list of pip requirement strings, e.g. ["torch==2.7.1"].

    Returns:
        "cu124" | "cu128" | None (None when nothing can be inferred, or when
        the requested pins conflict with each other)
    """

    if not packages:
        return None

    desired_cuda = None

    for raw_req in packages:
        if not raw_req:
            continue

        req = str(raw_req).strip()
        # Skip empty entries and pip flags such as "--pre".
        if not req or req.startswith("-"):
            continue

        match = REQ_SPEC_REGEX.match(req)
        if not match:
            continue

        name = match.group("name").lower()
        op = match.group("op")
        version = match.group("version")

        major_minor = _parse_major_minor(version)
        if not major_minor:
            continue

        # Map torchvision's versioning scheme to the matching torch major/minor
        # (torchvision 0.X pairs with torch 2.(X-15), e.g. 0.22 <-> torch 2.7).
        if name == "torchvision":
            tv_major, tv_minor = major_minor
            if tv_major != 0:
                continue
            torch_major_minor = (2, max(0, tv_minor - 15))
        elif name in ("torch", "torchaudio"):
            torch_major_minor = major_minor
        else:
            continue

        at_least_2_7 = _is_major_minor_gte(torch_major_minor, TORCH_2_7)

        # Per PEP 440, "~=X.Y.Z" is a compatible release: >= X.Y.Z, == X.Y.*,
        # i.e. it pins major.minor exactly like "==". A two-component "~=X.Y"
        # only pins the major version, so it is handled like ">=".
        pins_major_minor = op == "==" or (op == "~=" and re.match(r"^\d+\.\d+\.\d", version))

        required_cuda = None
        if pins_major_minor:
            required_cuda = "cu128" if at_least_2_7 else "cu124"
        elif op in (">=", ">", "~="):
            # Open-ended lower bound: only conclusive when it excludes <2.7.
            if at_least_2_7:
                required_cuda = "cu128"
        elif op in ("<", "<="):
            # Upper bound: only conclusive when it excludes >=2.7.
            if _is_major_minor_lt(torch_major_minor, TORCH_2_7):
                required_cuda = "cu124"

        if required_cuda is None:
            continue

        if desired_cuda is None:
            desired_cuda = required_cuda
        elif desired_cuda != required_cuda:
            # Conflicting version pins, leave platform unchanged and let pip resolve/fail.
            return None

    return desired_cuda


def _maybe_override_nvidia_cuda_platform(torch_platform, packages, gpu_infos):
    """
    Adjust cu124/cu128 index selection based on pinned torch/torchvision/torchaudio versions.

    Returns *torch_platform* unchanged for non-CUDA platforms, when nothing can
    be inferred from *packages*, or when a demotion would strand a Blackwell GPU.
    """
    is_cuda_platform = bool(torch_platform) and CUDA_REGEX.match(torch_platform) is not None
    if not is_cuda_platform:
        return torch_platform

    wanted_cuda = _get_cuda_platform_for_pytorch_packages(packages)
    if wanted_cuda not in ("cu124", "cu128"):
        return torch_platform

    current_cuda = torch_platform.split("/", 1)[-1]  # strip any "nightly/" prefix
    if current_cuda == wanted_cuda:
        return torch_platform

    # Do not demote Blackwell GPUs from cu128 -> cu124; older torch versions won't support them anyway.
    if (current_cuda, wanted_cuda) == ("cu128", "cu124"):
        device_names = {gpu.device_name for gpu in (gpu_infos or [])}
        if device_names and get_nvidia_arch(device_names) == 12:
            return torch_platform

    return _cuda_platform_with_prefix(torch_platform, wanted_cuda)


def install(packages=[], use_uv=False):
"""
packages: a list of strings with package names (and optionally their versions in pip-format). e.g. ["torch", "torchvision"] or ["torch>=2.0", "torchaudio==0.16.0"]. Defaults to ["torch", "torchvision", "torchaudio"].
Expand All @@ -99,6 +226,7 @@ def install(packages=[], use_uv=False):

gpu_infos = get_gpus()
torch_platform = get_torch_platform(gpu_infos)
torch_platform = _maybe_override_nvidia_cuda_platform(torch_platform, packages, gpu_infos)
cmds = get_install_commands(torch_platform, packages)
cmds = get_pip_commands(cmds, use_uv=use_uv)
run_commands(cmds)