From 5acb900097a4c8fab6d0ae67f95abd9261a32390 Mon Sep 17 00:00:00 2001 From: James Nye Date: Wed, 25 Mar 2026 21:28:51 -0700 Subject: [PATCH 1/7] =?UTF-8?q?feat:=20AMD=20ROCm=20GPU=20support=20?= =?UTF-8?q?=E2=80=94=20device=20enumeration,=20compile=20mode,=20SDPA=20fi?= =?UTF-8?q?x?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add support for AMD GPUs (RDNA3/RDNA4) via ROCm's HIP translation layer. Most PyTorch code works unchanged since HIP maps torch.cuda.* APIs, but three areas needed explicit AMD handling: device_utils.py: - Add GPUInfo dataclass and enumerate_gpus() function - Tries nvidia-smi, then amd-smi (ROCm 6.0+), then rocm-smi (legacy), then torch.cuda fallback — works on NVIDIA, AMD, or either inference_engine.py: - torch.compile uses "max-autotune-no-cudagraphs" on ROCm to avoid a known HIP graph segfault on large graphs (pytorch/pytorch#155720) - Auto-sets TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 on ROCm so SDPA dispatches to AOTriton flash attention on RDNA3 (without this it silently falls back to O(n^2) math backend) - Skips torch.compile entirely on Windows ROCm where Triton kernel compilation hangs indefinitely (eager fallback works fine) pyproject.toml: - Add "rocm" optional extra with install instructions README.md: - Add AMD ROCm to Hardware Requirements - Add ROCm install instructions for Linux and Windows - Add dedicated AMD ROCm Setup section with supported GPUs, automatic behavior, and known limitations Tested on RX 7800 XT (gfx1101) Windows with AMD ROCm 7.2: - torch.cuda.is_available() = True - SDPA with float16 = OK - Raw inference = OK - torch.compile = hangs on Windows (skipped), works on Linux --- CorridorKeyModule/inference_engine.py | 84 +++++++++++--- README.md | 79 ++++++++++++- corridorkey_cli.py | 49 ++++++-- device_utils.py | 159 ++++++++++++++++++++++++++ pyproject.toml | 4 + 5 files changed, 348 insertions(+), 27 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index 3f43588d..ec6f0b8d 100755 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -5,16 +5,42 @@ import os import sys -import cv2 -import numpy as np -import torch -import torch.nn.functional as F -import torchvision -import torchvision.transforms.v2 as T -import torchvision.transforms.v2.functional as TF - -from .core import color_utils as cu -from .core.model_transformer import GreenFormer +# ROCm: must be set before importing torch so the CUDA allocator picks them up. +# Detection: /opt/rocm (Linux), HIP_PATH (Windows default C:\hip), or explicit opt-in. +_is_rocm_system = ( + os.path.exists("/opt/rocm") + or os.environ.get("HIP_PATH") is not None + or os.environ.get("HIP_VISIBLE_DEVICES") is not None + or os.environ.get("CORRIDORKEY_ROCM") == "1" +) +if _is_rocm_system: + os.environ.setdefault("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "1") + os.environ.setdefault("MIOPEN_FIND_MODE", "2") + os.environ.setdefault("MIOPEN_LOG_LEVEL", "0") + # Enable GTT (system RAM as GPU overflow) on Linux for 16GB cards. + # pytorch-rocm-gtt must be installed separately: pip install pytorch-rocm-gtt + try: + import pytorch_rocm_gtt + + pytorch_rocm_gtt.patch() + except ImportError: + pass + +# Persist torch.compile autotune cache across runs (default is /tmp which +# gets wiped on reboot — saves 10-20 min re-autotuning on ROCm, ~30s on CUDA) +_inductor_cache = os.path.join(os.path.expanduser("~"), ".cache", "corridorkey", "inductor") +os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", _inductor_cache) + +import cv2 # noqa: E402 +import numpy as np # noqa: E402 +import torch # noqa: E402 +import torch.nn.functional as F # noqa: E402 +import torchvision # noqa: E402 +import torchvision.transforms.v2 as T # noqa: E402 +import torchvision.transforms.v2.functional as TF # noqa: E402 + +from .core import color_utils as cu # noqa: E402 +from .core.model_transformer import GreenFormer # noqa: E402 logger = logging.getLogger(__name__) @@ -52,8 +78,15 @@ def __init__( self.model = self._load_model() - # We only tested compilation on Windows and Linux. For other platforms compilation is disabled as a precaution. - if sys.platform == "linux" or sys.platform == "win32": + is_rocm = hasattr(torch.version, "hip") and torch.version.hip + + # torch.compile is tested on CUDA (Windows + Linux) and ROCm (Linux). + # ROCm on Windows hangs during Triton kernel compilation — skip it. + # CORRIDORKEY_SKIP_COMPILE=1 forces eager mode (useful for testing). + skip_compile = (is_rocm and sys.platform == "win32") or os.environ.get("CORRIDORKEY_SKIP_COMPILE") == "1" + if skip_compile: + logger.info("Skipping torch.compile (eager mode)") + elif sys.platform == "linux" or sys.platform == "win32": self._compile() def _load_model(self) -> GreenFormer: @@ -116,20 +149,41 @@ def _load_model(self) -> GreenFormer: return model def _compile(self): + is_rocm = hasattr(torch.version, "hip") and torch.version.hip + if is_rocm: + # "default" avoids the heavy autotuning that OOM-kills 16GB cards + # at 2048x2048. Still compiles Triton kernels, just skips the + # exhaustive benchmarking. HIP graphs are also avoided (segfault + # on large graphs — pytorch/pytorch#155720). + compile_mode = "default" + else: + compile_mode = "max-autotune" + try: - compiled_model = torch.compile(self.model, mode="max-autotune") - # Trigger compilation with a dummy input + logger.info( + "Compiling model (mode=%s) — this may take 10-20 minutes on first run. " + "Compiled kernels are cached for future runs.", + compile_mode, + ) + compiled_model = torch.compile(self.model, mode=compile_mode) + # Trigger compilation with a dummy input (the actual compile + # happens here, not in the torch.compile() call above) dummy_input = torch.zeros( 1, 4, self.img_size, self.img_size, dtype=self.model_precision, device=self.device ) with torch.inference_mode(): compiled_model(dummy_input) + del dummy_input + if torch.cuda.is_available(): + torch.cuda.empty_cache() self.model = compiled_model + logger.info("Model compiled successfully (mode=%s)", compile_mode) except Exception as e: logger.info(f"Compilation error: {e}") logger.warning("Model compilation failed. Falling back to eager mode.") - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() def _preprocess_input( self, image_batch: torch.Tensor, mask_batch_linear: torch.Tensor, input_is_linear: bool diff --git a/README.md b/README.md index b42eed54..07f2788e 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,8 @@ This project was designed and built on a Linux workstation (Puget Systems PC) eq The most recent build should work on computers with 6-8 gig of VRAM, and it can run on most M1+ Mac systems with unified memory. Yes, it might even work on your old Macbook pro. Let us know on the Discord! -* **Windows Users:** To run GPU acceleration natively on Windows, your system MUST have NVIDIA drivers that support **CUDA 12.8 or higher** installed. If your drivers only support older CUDA versions, the installer will likely fallback to the CPU. +* **Windows Users (NVIDIA):** To run GPU acceleration natively on Windows, your system MUST have NVIDIA drivers that support **CUDA 12.8 or higher** installed. If your drivers only support older CUDA versions, the installer will likely fallback to the CPU. +* **AMD GPU Users (ROCm):** AMD Radeon RX 7000 series (RDNA3) and RX 9000 series (RDNA4) are supported via ROCm on **Linux**. Windows ROCm support is experimental (torch.compile is not yet functional). See the [AMD ROCm Setup](#amd-rocm-setup) section below. * **GVM (Optional):** Requires approximately **80 GB of VRAM** and utilizes massive Stable Video Diffusion models. * **VideoMaMa (Optional):** Natively requires a massive chunk of VRAM as well (originally 80GB+). While the community has tweaked the architecture to run at less than 24GB, those extreme memory optimizations have not yet been fully implemented in this repository. * **BiRefNet (Optional):** Lightweight AlphaHint generator option. @@ -72,6 +73,7 @@ This project uses **[uv](https://docs.astral.sh/uv/)** to manage Python and all uv sync --extra cuda # CUDA GPU acceleration (Linux/Windows) uv sync --extra mlx # Apple Silicon MLX acceleration ``` + For **AMD ROCm** setup, see the [AMD ROCm Setup](#amd-rocm-setup) section below. 4. **Download the Models:** * **CorridorKey v1.0 Model (~300MB):** Downloads automatically on first run. If no `.pth` file is found in `CorridorKeyModule/checkpoints/`, the engine fetches it from [CorridorKey's HuggingFace](https://huggingface.co/nikopueringer/CorridorKey_v1.0) and saves it as `CorridorKey.pth`. No manual download needed. * **GVM Weights (Optional):** [HuggingFace: geyongtao/gvm](https://huggingface.co/geyongtao/gvm) @@ -220,6 +222,81 @@ uv run python corridorkey_cli.py wizard --win_path "/path/to/clips" **Use native MLX instead of PyTorch MPS:** MLX avoids PyTorch's MPS layer entirely and typically runs faster on Apple Silicon. See the [Backend Selection](#backend-selection) section below for setup steps. +### AMD ROCm Setup + +CorridorKey supports AMD GPUs via PyTorch's ROCm/HIP backend. The `torch.cuda.*` API works transparently on AMD — HIP intercepts all CUDA calls at runtime, so the inference code runs unchanged. + +**Supported GPUs (ROCm 7.2+):** +- RX 7900 XTX (24GB) / XT (20GB) / GRE (16GB) — RDNA3, gfx1100 +- RX 7800 XT (16GB) / 7700 XT (12GB) — RDNA3, gfx1101 +- RX 9070 XT / 9070 (16GB) — RDNA4, gfx1201 + +**VRAM requirements:** CorridorKey inference at 2048x2048 needs ~18GB VRAM. The RX 7900 XTX (24GB) and RX 7900 XT (20GB) run at full resolution. Cards with 16GB (RX 7800 XT, 9070 XT) work on Windows (which uses system RAM as overflow) but may OOM on Linux — see notes below. + +**Linux native (recommended):** +```bash +# Install AMD's ROCm torch wheels, then sync everything else +pip install torch==2.8.0 torchvision==0.23.0 --index-url https://download.pytorch.org/whl/rocm6.3 +uv sync + +# Verify +uv run python -c "import torch; print(torch.cuda.is_available(), torch.cuda.get_device_name(0))" +``` + +**WSL2 (Windows Subsystem for Linux):** + +Requires AMD Adrenalin 26.1.1+ driver on Windows. Install ROCm inside WSL2, then use AMD's WSL-specific torch wheels: + +```bash +# 1. Install ROCm for WSL (Ubuntu 24.04) +sudo apt update +wget https://repo.radeon.com/amdgpu-install/7.2/ubuntu/noble/amdgpu-install_7.2.70200-1_all.deb +sudo apt install ./amdgpu-install_7.2.70200-1_all.deb +amdgpu-install -y --usecase=wsl,rocm --no-dkms + +# 2. Verify GPU is visible +rocminfo # should show your AMD GPU + +# 3. Install AMD's WSL torch wheels (Python 3.12) +pip3 install \ + https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp312-cp312-linux_x86_64.whl \ + https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.24.0%2Brocm7.2.0.gitb919bd0c-cp312-cp312-linux_x86_64.whl \ + https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.5.1%2Brocm7.2.0.gita272dfa8-cp312-cp312-linux_x86_64.whl + +# 4. Fix WSL runtime library conflict (required) +location=$(pip3 show torch | grep Location | awk -F ": " '{print $2}') +rm -f ${location}/torch/lib/libhsa-runtime64.so* + +# 5. Install CorridorKey deps AFTER torch (so pip doesn't overwrite ROCm torch) +pip3 install -e . +``` + +**Windows native (experimental):** + +Windows ROCm requires Python 3.12 and AMD Adrenalin 25.3.1+ driver. `torch.compile` does not work on Windows ROCm — inference runs in eager mode (significantly slower than Linux). + +```powershell +py -3.12 -m pip install https://repo.radeon.com/rocm/windows/rocm-rel-7.2/rocm-7.2.0.dev0-py3-none-win_amd64.whl +py -3.12 -m pip install --no-cache-dir https://repo.radeon.com/rocm/windows/rocm-rel-7.2/torch-2.9.1+rocmsdk20260116-cp312-cp312-win_amd64.whl https://repo.radeon.com/rocm/windows/rocm-rel-7.2/torchvision-0.24.1+rocmsdk20260116-cp312-cp312-win_amd64.whl +``` + +**What CorridorKey does automatically on ROCm:** +- Sets `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` so SDPA dispatches to flash attention kernels on RDNA3 (without this, attention falls back to a slow O(n²) path) +- Sets `MIOPEN_FIND_MODE=2` for faster convolution kernel selection (reduces warmup from 5-8 minutes to seconds) +- Uses `torch.compile(mode="default")` on Linux to avoid OOM during kernel autotuning on 16GB cards +- Skips `torch.compile` entirely on Windows ROCm where Triton compilation hangs +- Auto-detects ROCm via `/opt/rocm` (Linux), `HIP_PATH` (Windows), or `CORRIDORKEY_ROCM=1` env var (explicit opt-in) + +**First-run note:** The first inference run on a new AMD GPU triggers Triton kernel autotuning (10-20 minutes). This is cached in `~/.cache/corridorkey/inductor/` and only happens once per GPU architecture. Subsequent runs start instantly. + +**16GB cards on Linux:** CorridorKey at 2048x2048 needs ~18GB. Windows handles this transparently via shared GPU memory (system RAM overflow). On Linux, the GPU has a hard VRAM limit. If you hit OOM on a 16GB card, install `pytorch-rocm-gtt` to enable GTT (system RAM as GPU overflow) — CorridorKey detects and uses it automatically: +```bash +pip install pytorch-rocm-gtt +``` +GTT memory is accessed over PCIe (~10-20x slower than VRAM), so expect slower frame times on 16GB cards vs 20-24GB cards. + +**WSL2 limitation:** WSL2 cannot use GTT or shared memory — it has a hard VRAM limit. 16GB cards will OOM in WSL2 at 2048x2048. Use Windows native instead, or a card with 20GB+ VRAM. + ## Backend Selection CorridorKey supports two inference backends: diff --git a/corridorkey_cli.py b/corridorkey_cli.py index b040b3fe..41bdaffc 100644 --- a/corridorkey_cli.py +++ b/corridorkey_cli.py @@ -19,17 +19,44 @@ import shutil import sys import warnings -from typing import Annotated, Optional -import typer -from rich.console import Console -from rich.logging import RichHandler -from rich.panel import Panel -from rich.progress import BarColumn, MofNCompleteColumn, Progress, SpinnerColumn, TaskID, TextColumn, TimeElapsedColumn -from rich.prompt import Confirm, IntPrompt, Prompt -from rich.table import Table +# ROCm: must be set before any torch import (including transitive via diffusers/GVM) +_is_rocm = ( + os.path.exists("/opt/rocm") + or os.environ.get("HIP_PATH") is not None + or os.environ.get("HIP_VISIBLE_DEVICES") is not None + or os.environ.get("CORRIDORKEY_ROCM") == "1" +) +if _is_rocm: + os.environ.setdefault("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "1") + os.environ.setdefault("MIOPEN_FIND_MODE", "2") + os.environ.setdefault("MIOPEN_LOG_LEVEL", "0") + try: + import pytorch_rocm_gtt + + pytorch_rocm_gtt.patch() + except ImportError: + pass + +from typing import Annotated, Optional # noqa: E402 + +import typer # noqa: E402 +from rich.console import Console # noqa: E402 +from rich.logging import RichHandler # noqa: E402 +from rich.panel import Panel # noqa: E402 +from rich.progress import ( # noqa: E402 + BarColumn, + MofNCompleteColumn, + Progress, + SpinnerColumn, + TaskID, + TextColumn, + TimeElapsedColumn, +) +from rich.prompt import Confirm, IntPrompt, Prompt # noqa: E402 +from rich.table import Table # noqa: E402 -from clip_manager import ( +from clip_manager import ( # noqa: E402 LINUX_MOUNT_ROOT, ClipEntry, InferenceSettings, @@ -43,8 +70,8 @@ run_videomama, scan_clips, ) -from CorridorKeyModule.backend import resolve_backend -from device_utils import resolve_device +from CorridorKeyModule.backend import resolve_backend # noqa: E402 +from device_utils import resolve_device # noqa: E402 logger = logging.getLogger(__name__) console = Console() diff --git a/device_utils.py b/device_utils.py index 6894d082..8e12f5e3 100644 --- a/device_utils.py +++ b/device_utils.py @@ -2,6 +2,8 @@ import logging import os +import subprocess +from dataclasses import dataclass import torch @@ -67,6 +69,163 @@ def resolve_device(requested: str | None = None) -> str: return device +@dataclass +class GPUInfo: + """Information about a single GPU.""" + + index: int + name: str + vram_total_gb: float + vram_free_gb: float + + +def _enumerate_nvidia() -> list[GPUInfo] | None: + """Enumerate NVIDIA GPUs via nvidia-smi. Returns None if unavailable.""" + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=index,name,memory.total,memory.free", + "--format=csv,nounits,noheader", + ], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode != 0: + return None + gpus: list[GPUInfo] = [] + for line in result.stdout.strip().split("\n"): + parts = [p.strip() for p in line.split(",")] + if len(parts) >= 4: + gpus.append( + GPUInfo( + index=int(parts[0]), + name=parts[1], + vram_total_gb=float(parts[2]) / 1024, + vram_free_gb=float(parts[3]) / 1024, + ) + ) + return gpus + except (FileNotFoundError, subprocess.TimeoutExpired): + return None + + +def _enumerate_amd() -> list[GPUInfo] | None: + """Enumerate AMD GPUs via amd-smi (ROCm). Returns None if unavailable. + + Tries amd-smi first (modern), then rocm-smi (legacy). + """ + # Try amd-smi (ROCm 6.0+) + try: + import json as _json + + result = subprocess.run( + ["amd-smi", "static", "--json"], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode == 0: + data = _json.loads(result.stdout) + gpus: list[GPUInfo] = [] + for i, gpu in enumerate(data): + name = gpu.get("asic", {}).get("market_name", f"AMD GPU {i}") + vram_info = gpu.get("vram", {}) + total_mb = vram_info.get("size", {}).get("value", 0) + total_gb = float(total_mb) / 1024 if total_mb else 0 + gpus.append(GPUInfo(index=i, name=name, vram_total_gb=total_gb, vram_free_gb=total_gb)) + if gpus: + # Try to get live VRAM usage from monitor + try: + mon = subprocess.run( + ["amd-smi", "monitor", "--vram", "--json"], + capture_output=True, + text=True, + timeout=5, + ) + if mon.returncode == 0: + mon_data = _json.loads(mon.stdout) + for entry in mon_data: + idx = entry.get("gpu", 0) + used_pct = entry.get("vram_use", 0) + if idx < len(gpus) and gpus[idx].vram_total_gb > 0: + used_gb = gpus[idx].vram_total_gb * float(used_pct) / 100 + gpus[idx].vram_free_gb = gpus[idx].vram_total_gb - used_gb + except Exception: + pass + return gpus + except (FileNotFoundError, subprocess.TimeoutExpired): + pass + + # Fallback: rocm-smi (legacy, deprecated but still ships) + try: + result = subprocess.run( + ["rocm-smi", "--showid", "--showmeminfo", "vram", "--csv"], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode == 0 and result.stdout.strip(): + gpus = [] + for line in result.stdout.strip().split("\n")[1:]: # skip header + parts = [p.strip() for p in line.split(",")] + if len(parts) >= 3: + idx = int(parts[0]) if parts[0].isdigit() else len(gpus) + total_b = int(parts[1]) if parts[1].isdigit() else 0 + used_b = int(parts[2]) if parts[2].isdigit() else 0 + gpus.append( + GPUInfo( + index=idx, + name=f"AMD GPU {idx}", + vram_total_gb=total_b / (1024**3), + vram_free_gb=(total_b - used_b) / (1024**3), + ) + ) + if gpus: + return gpus + except (FileNotFoundError, subprocess.TimeoutExpired): + pass + + return None + + +def enumerate_gpus() -> list[GPUInfo]: + """List all available GPUs with VRAM info. + + Tries nvidia-smi (NVIDIA), then amd-smi/rocm-smi (AMD ROCm), + then falls back to torch.cuda API. + Returns an empty list on non-GPU systems. + """ + # NVIDIA + gpus = _enumerate_nvidia() + if gpus is not None: + return gpus + + # AMD ROCm + gpus = _enumerate_amd() + if gpus is not None: + return gpus + + # Fallback to torch (works for both NVIDIA and ROCm via HIP) + if torch.cuda.is_available(): + fallback: list[GPUInfo] = [] + for i in range(torch.cuda.device_count()): + props = torch.cuda.get_device_properties(i) + total = props.total_memory / (1024**3) + fallback.append( + GPUInfo( + index=i, + name=props.name, + vram_total_gb=total, + vram_free_gb=total, # can't query free without setting device + ) + ) + return fallback + + return [] + + def clear_device_cache(device: torch.device | str) -> None: """Clear GPU memory cache if applicable (no-op for CPU).""" device_type = device.type if isinstance(device, torch.device) else device diff --git a/pyproject.toml b/pyproject.toml index f4f76372..fcd3fca5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,10 @@ cuda = [ mlx = [ "corridorkey-mlx ; python_version >= '3.11'", ] +# ROCm (AMD GPU): torch must be installed manually from the ROCm wheel index +# before running uv sync — see README.md "AMD ROCm Setup" for instructions. +# A rocm extra is not provided because the ROCm wheel metadata conflicts +# with PyPI's nvidia-* packages, making uv sync --extra rocm unreliable. [dependency-groups] dev = ["pytest", "pytest-cov", "ruff", "hypothesis"] From a15655ae1d013064f701c643c50f54fa165dec34 Mon Sep 17 00:00:00 2001 From: James Nye Date: Thu, 26 Mar 2026 11:09:40 -0700 Subject: [PATCH 2/7] =?UTF-8?q?refactor:=20address=20PR=20review=20?= =?UTF-8?q?=E2=80=94=20deduplicate=20ROCm=20detection,=20harden=20parsing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Extract is_rocm_system() and setup_rocm_env() into device_utils.py as shared functions. Both corridorkey_cli.py and inference_engine.py now call setup_rocm_env() instead of duplicating the detection + env var logic. - device_utils.py restructured so setup_rocm_env() is defined before `import torch`, allowing callers to set env vars before torch init. - Broaden pytorch_rocm_gtt exception handler from ImportError to Exception (catches patch() runtime errors too, not just missing package). - Add per-entry error handling in _enumerate_amd() — malformed amd-smi JSON entries are skipped instead of crashing the whole enumeration. Add json.JSONDecodeError to the outer exception handler. - Use top-level `import json` instead of local `import json as _json`. - Change MIOPEN_LOG_LEVEL from 0 (silent) to 4 (suppress info/debug, keep warnings and errors visible). - Compile message is now conditional: "10-20 minutes on first run (ROCm)" vs just "Compiling model..." on NVIDIA where it takes ~30 seconds. - README VRAM claim clarified: ~10GB on NVIDIA, ~18GB on AMD due to HIP allocator overhead. Prevents scaring off 12GB NVIDIA card owners. --- CorridorKeyModule/inference_engine.py | 37 ++++++---------- README.md | 2 +- corridorkey_cli.py | 19 ++------ device_utils.py | 62 +++++++++++++++++++++------ 4 files changed, 66 insertions(+), 54 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index ec6f0b8d..4e2e91e8 100755 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -5,26 +5,10 @@ import os import sys -# ROCm: must be set before importing torch so the CUDA allocator picks them up. -# Detection: /opt/rocm (Linux), HIP_PATH (Windows default C:\hip), or explicit opt-in. -_is_rocm_system = ( - os.path.exists("/opt/rocm") - or os.environ.get("HIP_PATH") is not None - or os.environ.get("HIP_VISIBLE_DEVICES") is not None - or os.environ.get("CORRIDORKEY_ROCM") == "1" -) -if _is_rocm_system: - os.environ.setdefault("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "1") - os.environ.setdefault("MIOPEN_FIND_MODE", "2") - os.environ.setdefault("MIOPEN_LOG_LEVEL", "0") - # Enable GTT (system RAM as GPU overflow) on Linux for 16GB cards. - # pytorch-rocm-gtt must be installed separately: pip install pytorch-rocm-gtt - try: - import pytorch_rocm_gtt - - pytorch_rocm_gtt.patch() - except ImportError: - pass +# ROCm env vars + autotune cache must be set before importing torch. +from device_utils import setup_rocm_env as _setup_rocm_env # noqa: E402 — no torch import + +_setup_rocm_env() # Persist torch.compile autotune cache across runs (default is /tmp which # gets wiped on reboot — saves 10-20 min re-autotuning on ROCm, ~30s on CUDA) @@ -160,11 +144,14 @@ def _compile(self): compile_mode = "max-autotune" try: - logger.info( - "Compiling model (mode=%s) — this may take 10-20 minutes on first run. " - "Compiled kernels are cached for future runs.", - compile_mode, - ) + if is_rocm: + logger.info( + "Compiling model (mode=%s) — this may take 10-20 minutes on first run (ROCm). " + "Compiled kernels are cached for future runs.", + compile_mode, + ) + else: + logger.info("Compiling model (mode=%s)...", compile_mode) compiled_model = torch.compile(self.model, mode=compile_mode) # Trigger compilation with a dummy input (the actual compile # happens here, not in the torch.compile() call above) diff --git a/README.md b/README.md index 07f2788e..b9ee3800 100644 --- a/README.md +++ b/README.md @@ -231,7 +231,7 @@ CorridorKey supports AMD GPUs via PyTorch's ROCm/HIP backend. The `torch.cuda.*` - RX 7800 XT (16GB) / 7700 XT (12GB) — RDNA3, gfx1101 - RX 9070 XT / 9070 (16GB) — RDNA4, gfx1201 -**VRAM requirements:** CorridorKey inference at 2048x2048 needs ~18GB VRAM. The RX 7900 XTX (24GB) and RX 7900 XT (20GB) run at full resolution. Cards with 16GB (RX 7800 XT, 9070 XT) work on Windows (which uses system RAM as overflow) but may OOM on Linux — see notes below. +**VRAM requirements:** CorridorKey inference at 2048x2048 uses ~10GB on NVIDIA but ~18GB on AMD due to HIP allocator overhead. The RX 7900 XTX (24GB) and RX 7900 XT (20GB) run at full resolution. Cards with 16GB (RX 7800 XT, 9070 XT) work on Windows (which uses system RAM as overflow) but may OOM on Linux — see notes below. **Linux native (recommended):** ```bash diff --git a/corridorkey_cli.py b/corridorkey_cli.py index 41bdaffc..3665261f 100644 --- a/corridorkey_cli.py +++ b/corridorkey_cli.py @@ -20,23 +20,10 @@ import sys import warnings -# ROCm: must be set before any torch import (including transitive via diffusers/GVM) -_is_rocm = ( - os.path.exists("/opt/rocm") - or os.environ.get("HIP_PATH") is not None - or os.environ.get("HIP_VISIBLE_DEVICES") is not None - or os.environ.get("CORRIDORKEY_ROCM") == "1" -) -if _is_rocm: - os.environ.setdefault("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "1") - os.environ.setdefault("MIOPEN_FIND_MODE", "2") - os.environ.setdefault("MIOPEN_LOG_LEVEL", "0") - try: - import pytorch_rocm_gtt +# ROCm env vars must be set before any torch import (including transitive via diffusers/GVM) +from device_utils import setup_rocm_env # noqa: E402 — no torch import in this path - pytorch_rocm_gtt.patch() - except ImportError: - pass +setup_rocm_env() from typing import Annotated, Optional # noqa: E402 diff --git a/device_utils.py b/device_utils.py index 8e12f5e3..ce3a4646 100644 --- a/device_utils.py +++ b/device_utils.py @@ -1,18 +1,55 @@ """Centralized cross-platform device selection for CorridorKey.""" +import json import logging import os import subprocess from dataclasses import dataclass -import torch - logger = logging.getLogger(__name__) DEVICE_ENV_VAR = "CORRIDORKEY_DEVICE" VALID_DEVICES = ("auto", "cuda", "mps", "cpu") +def is_rocm_system() -> bool: + """Detect if the system has AMD ROCm available (before or after torch import). + + Checks: /opt/rocm (Linux), HIP_PATH (Windows, default C:\\hip), + HIP_VISIBLE_DEVICES (any platform), CORRIDORKEY_ROCM=1 (explicit opt-in). + """ + return ( + os.path.exists("/opt/rocm") + or os.environ.get("HIP_PATH") is not None + or os.environ.get("HIP_VISIBLE_DEVICES") is not None + or os.environ.get("CORRIDORKEY_ROCM") == "1" + ) + + +def setup_rocm_env() -> None: + """Set ROCm environment variables and apply optional patches. + + Must be called before importing torch. Safe to call on non-ROCm systems (no-op). + """ + if not is_rocm_system(): + return + os.environ.setdefault("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "1") + os.environ.setdefault("MIOPEN_FIND_MODE", "2") + # Level 4 = suppress info/debug but keep warnings and errors visible + os.environ.setdefault("MIOPEN_LOG_LEVEL", "4") + # Enable GTT (system RAM as GPU overflow) on Linux for 16GB cards. + # pytorch-rocm-gtt must be installed separately: pip install pytorch-rocm-gtt + try: + import pytorch_rocm_gtt + + pytorch_rocm_gtt.patch() + except Exception: + pass # not installed, or patch failed — non-fatal + + +import torch # noqa: E402 — deferred so setup_rocm_env() can run first + + def detect_best_device() -> str: """Auto-detect best available device: CUDA > MPS > CPU.""" if torch.cuda.is_available(): @@ -118,8 +155,6 @@ def _enumerate_amd() -> list[GPUInfo] | None: """ # Try amd-smi (ROCm 6.0+) try: - import json as _json - result = subprocess.run( ["amd-smi", "static", "--json"], capture_output=True, @@ -127,14 +162,17 @@ def _enumerate_amd() -> list[GPUInfo] | None: timeout=10, ) if result.returncode == 0: - data = _json.loads(result.stdout) + data = json.loads(result.stdout) gpus: list[GPUInfo] = [] for i, gpu in enumerate(data): - name = gpu.get("asic", {}).get("market_name", f"AMD GPU {i}") - vram_info = gpu.get("vram", {}) - total_mb = vram_info.get("size", {}).get("value", 0) - total_gb = float(total_mb) / 1024 if total_mb else 0 - gpus.append(GPUInfo(index=i, name=name, vram_total_gb=total_gb, vram_free_gb=total_gb)) + try: + name = gpu.get("asic", {}).get("market_name", f"AMD GPU {i}") + vram_info = gpu.get("vram", {}) + total_mb = vram_info.get("size", {}).get("value", 0) + total_gb = float(total_mb) / 1024 if total_mb else 0 + gpus.append(GPUInfo(index=i, name=name, vram_total_gb=total_gb, vram_free_gb=total_gb)) + except (KeyError, TypeError, ValueError): + logger.debug("Failed to parse amd-smi entry %d, skipping", i) if gpus: # Try to get live VRAM usage from monitor try: @@ -145,7 +183,7 @@ def _enumerate_amd() -> list[GPUInfo] | None: timeout=5, ) if mon.returncode == 0: - mon_data = _json.loads(mon.stdout) + mon_data = json.loads(mon.stdout) for entry in mon_data: idx = entry.get("gpu", 0) used_pct = entry.get("vram_use", 0) @@ -155,7 +193,7 @@ def _enumerate_amd() -> list[GPUInfo] | None: except Exception: pass return gpus - except (FileNotFoundError, subprocess.TimeoutExpired): + except (FileNotFoundError, subprocess.TimeoutExpired, json.JSONDecodeError): pass # Fallback: rocm-smi (legacy, deprecated but still ships) From a9a49fe525e4ff106be7e34eaa117acdbadd6a55 Mon Sep 17 00:00:00 2001 From: James Nye Date: Thu, 26 Mar 2026 12:49:56 -0700 Subject: [PATCH 3/7] =?UTF-8?q?cleanup:=20address=20self-review=20?= =?UTF-8?q?=E2=80=94=20remove=20dead=20code,=20fix=20comments,=20improve?= =?UTF-8?q?=20error=20handling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Remove dead code: GPUInfo, enumerate_gpus(), _enumerate_nvidia(), _enumerate_amd() — 158 lines with zero callers in upstream. These are cloud-platform code that doesn't belong in this PR. 2. Remove unused imports (json, subprocess, dataclass) that were only needed by the dead GPU enumeration code. 3. Fix misleading "deferred so setup_rocm_env() can run first" comment. The env vars are read at operation time by PyTorch/MIOpen C libraries, not at torch import time. The code works correctly either way. 4. Remove all noqa: E402 comments — imports are now in normal order since the pre-import ordering was unnecessary. 5. Better exception handling in setup_rocm_env(): ImportError caught separately (expected, silent) vs other exceptions (logged as warning with traceback for debugging broken pytorch-rocm-gtt installs). 6. Cache is_rocm as self._is_rocm instance attribute instead of recomputing hasattr(torch.version, "hip") in both __init__ and _compile. --- CorridorKeyModule/inference_engine.py | 39 +++--- corridorkey_cli.py | 31 +++-- device_utils.py | 173 ++------------------------ 3 files changed, 41 insertions(+), 202 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index 4e2e91e8..ad079ebd 100755 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -5,27 +5,28 @@ import os import sys -# ROCm env vars + autotune cache must be set before importing torch. -from device_utils import setup_rocm_env as _setup_rocm_env # noqa: E402 — no torch import +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +import torchvision +import torchvision.transforms.v2 as T +import torchvision.transforms.v2.functional as TF -_setup_rocm_env() +from device_utils import setup_rocm_env + +from .core import color_utils as cu +from .core.model_transformer import GreenFormer + +# ROCm env vars are read at operation time (not import time), so this is +# fine after importing torch. Also sets up pytorch-rocm-gtt if installed. +setup_rocm_env() # Persist torch.compile autotune cache across runs (default is /tmp which # gets wiped on reboot — saves 10-20 min re-autotuning on ROCm, ~30s on CUDA) _inductor_cache = os.path.join(os.path.expanduser("~"), ".cache", "corridorkey", "inductor") os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", _inductor_cache) -import cv2 # noqa: E402 -import numpy as np # noqa: E402 -import torch # noqa: E402 -import torch.nn.functional as F # noqa: E402 -import torchvision # noqa: E402 -import torchvision.transforms.v2 as T # noqa: E402 -import torchvision.transforms.v2.functional as TF # noqa: E402 - -from .core import color_utils as cu # noqa: E402 -from .core.model_transformer import GreenFormer # noqa: E402 - logger = logging.getLogger(__name__) @@ -60,14 +61,13 @@ def __init__( self.model_precision = model_precision + self._is_rocm = hasattr(torch.version, "hip") and torch.version.hip self.model = self._load_model() - is_rocm = hasattr(torch.version, "hip") and torch.version.hip - # torch.compile is tested on CUDA (Windows + Linux) and ROCm (Linux). # ROCm on Windows hangs during Triton kernel compilation — skip it. # CORRIDORKEY_SKIP_COMPILE=1 forces eager mode (useful for testing). - skip_compile = (is_rocm and sys.platform == "win32") or os.environ.get("CORRIDORKEY_SKIP_COMPILE") == "1" + skip_compile = (self._is_rocm and sys.platform == "win32") or os.environ.get("CORRIDORKEY_SKIP_COMPILE") == "1" if skip_compile: logger.info("Skipping torch.compile (eager mode)") elif sys.platform == "linux" or sys.platform == "win32": @@ -133,8 +133,7 @@ def _load_model(self) -> GreenFormer: return model def _compile(self): - is_rocm = hasattr(torch.version, "hip") and torch.version.hip - if is_rocm: + if self._is_rocm: # "default" avoids the heavy autotuning that OOM-kills 16GB cards # at 2048x2048. Still compiles Triton kernels, just skips the # exhaustive benchmarking. HIP graphs are also avoided (segfault @@ -144,7 +143,7 @@ def _compile(self): compile_mode = "max-autotune" try: - if is_rocm: + if self._is_rocm: logger.info( "Compiling model (mode=%s) — this may take 10-20 minutes on first run (ROCm). " "Compiled kernels are cached for future runs.", diff --git a/corridorkey_cli.py b/corridorkey_cli.py index 3665261f..b3e991a4 100644 --- a/corridorkey_cli.py +++ b/corridorkey_cli.py @@ -19,19 +19,13 @@ import shutil import sys import warnings +from typing import Annotated, Optional -# ROCm env vars must be set before any torch import (including transitive via diffusers/GVM) -from device_utils import setup_rocm_env # noqa: E402 — no torch import in this path - -setup_rocm_env() - -from typing import Annotated, Optional # noqa: E402 - -import typer # noqa: E402 -from rich.console import Console # noqa: E402 -from rich.logging import RichHandler # noqa: E402 -from rich.panel import Panel # noqa: E402 -from rich.progress import ( # noqa: E402 +import typer +from rich.console import Console +from rich.logging import RichHandler +from rich.panel import Panel +from rich.progress import ( BarColumn, MofNCompleteColumn, Progress, @@ -40,10 +34,10 @@ TextColumn, TimeElapsedColumn, ) -from rich.prompt import Confirm, IntPrompt, Prompt # noqa: E402 -from rich.table import Table # noqa: E402 +from rich.prompt import Confirm, IntPrompt, Prompt +from rich.table import Table -from clip_manager import ( # noqa: E402 +from clip_manager import ( LINUX_MOUNT_ROOT, ClipEntry, InferenceSettings, @@ -57,8 +51,11 @@ run_videomama, scan_clips, ) -from CorridorKeyModule.backend import resolve_backend # noqa: E402 -from device_utils import resolve_device # noqa: E402 +from CorridorKeyModule.backend import resolve_backend +from device_utils import resolve_device, setup_rocm_env + +# ROCm env vars are read at operation time, so this is fine after imports. +setup_rocm_env() logger = logging.getLogger(__name__) console = Console() diff --git a/device_utils.py b/device_utils.py index ce3a4646..6778c230 100644 --- a/device_utils.py +++ b/device_utils.py @@ -1,10 +1,7 @@ """Centralized cross-platform device selection for CorridorKey.""" -import json import logging import os -import subprocess -from dataclasses import dataclass logger = logging.getLogger(__name__) @@ -13,7 +10,7 @@ def is_rocm_system() -> bool: - """Detect if the system has AMD ROCm available (before or after torch import). + """Detect if the system has AMD ROCm available. Checks: /opt/rocm (Linux), HIP_PATH (Windows, default C:\\hip), HIP_VISIBLE_DEVICES (any platform), CORRIDORKEY_ROCM=1 (explicit opt-in). @@ -29,7 +26,9 @@ def is_rocm_system() -> bool: def setup_rocm_env() -> None: """Set ROCm environment variables and apply optional patches. - Must be called before importing torch. Safe to call on non-ROCm systems (no-op). + These env vars are read by PyTorch/MIOpen at operation time (not import + time), so calling this after ``import torch`` is fine. Safe to call on + non-ROCm systems (no-op). """ if not is_rocm_system(): return @@ -43,11 +42,13 @@ def setup_rocm_env() -> None: import pytorch_rocm_gtt pytorch_rocm_gtt.patch() + except ImportError: + pass # not installed — expected on most systems except Exception: - pass # not installed, or patch failed — non-fatal + logger.warning("pytorch-rocm-gtt is installed but patch() failed", exc_info=True) -import torch # noqa: E402 — deferred so setup_rocm_env() can run first +import torch # noqa: E402 def detect_best_device() -> str: @@ -106,164 +107,6 @@ def resolve_device(requested: str | None = None) -> str: return device -@dataclass -class GPUInfo: - """Information about a single GPU.""" - - index: int - name: str - vram_total_gb: float - vram_free_gb: float - - -def _enumerate_nvidia() -> list[GPUInfo] | None: - """Enumerate NVIDIA GPUs via nvidia-smi. Returns None if unavailable.""" - try: - result = subprocess.run( - [ - "nvidia-smi", - "--query-gpu=index,name,memory.total,memory.free", - "--format=csv,nounits,noheader", - ], - capture_output=True, - text=True, - timeout=5, - ) - if result.returncode != 0: - return None - gpus: list[GPUInfo] = [] - for line in result.stdout.strip().split("\n"): - parts = [p.strip() for p in line.split(",")] - if len(parts) >= 4: - gpus.append( - GPUInfo( - index=int(parts[0]), - name=parts[1], - vram_total_gb=float(parts[2]) / 1024, - vram_free_gb=float(parts[3]) / 1024, - ) - ) - return gpus - except (FileNotFoundError, subprocess.TimeoutExpired): - return None - - -def _enumerate_amd() -> list[GPUInfo] | None: - """Enumerate AMD GPUs via amd-smi (ROCm). Returns None if unavailable. - - Tries amd-smi first (modern), then rocm-smi (legacy). - """ - # Try amd-smi (ROCm 6.0+) - try: - result = subprocess.run( - ["amd-smi", "static", "--json"], - capture_output=True, - text=True, - timeout=10, - ) - if result.returncode == 0: - data = json.loads(result.stdout) - gpus: list[GPUInfo] = [] - for i, gpu in enumerate(data): - try: - name = gpu.get("asic", {}).get("market_name", f"AMD GPU {i}") - vram_info = gpu.get("vram", {}) - total_mb = vram_info.get("size", {}).get("value", 0) - total_gb = float(total_mb) / 1024 if total_mb else 0 - gpus.append(GPUInfo(index=i, name=name, vram_total_gb=total_gb, vram_free_gb=total_gb)) - except (KeyError, TypeError, ValueError): - logger.debug("Failed to parse amd-smi entry %d, skipping", i) - if gpus: - # Try to get live VRAM usage from monitor - try: - mon = subprocess.run( - ["amd-smi", "monitor", "--vram", "--json"], - capture_output=True, - text=True, - timeout=5, - ) - if mon.returncode == 0: - mon_data = json.loads(mon.stdout) - for entry in mon_data: - idx = entry.get("gpu", 0) - used_pct = entry.get("vram_use", 0) - if idx < len(gpus) and gpus[idx].vram_total_gb > 0: - used_gb = gpus[idx].vram_total_gb * float(used_pct) / 100 - gpus[idx].vram_free_gb = gpus[idx].vram_total_gb - used_gb - except Exception: - pass - return gpus - except (FileNotFoundError, subprocess.TimeoutExpired, json.JSONDecodeError): - pass - - # Fallback: rocm-smi (legacy, deprecated but still ships) - try: - result = subprocess.run( - ["rocm-smi", "--showid", "--showmeminfo", "vram", "--csv"], - capture_output=True, - text=True, - timeout=10, - ) - if result.returncode == 0 and result.stdout.strip(): - gpus = [] - for line in result.stdout.strip().split("\n")[1:]: # skip header - parts = [p.strip() for p in line.split(",")] - if len(parts) >= 3: - idx = int(parts[0]) if parts[0].isdigit() else len(gpus) - total_b = int(parts[1]) if parts[1].isdigit() else 0 - used_b = int(parts[2]) if parts[2].isdigit() else 0 - gpus.append( - GPUInfo( - index=idx, - name=f"AMD GPU {idx}", - vram_total_gb=total_b / (1024**3), - vram_free_gb=(total_b - used_b) / (1024**3), - ) - ) - if gpus: - return gpus - except (FileNotFoundError, subprocess.TimeoutExpired): - pass - - return None - - -def enumerate_gpus() -> list[GPUInfo]: - """List all available GPUs with VRAM info. - - Tries nvidia-smi (NVIDIA), then amd-smi/rocm-smi (AMD ROCm), - then falls back to torch.cuda API. - Returns an empty list on non-GPU systems. - """ - # NVIDIA - gpus = _enumerate_nvidia() - if gpus is not None: - return gpus - - # AMD ROCm - gpus = _enumerate_amd() - if gpus is not None: - return gpus - - # Fallback to torch (works for both NVIDIA and ROCm via HIP) - if torch.cuda.is_available(): - fallback: list[GPUInfo] = [] - for i in range(torch.cuda.device_count()): - props = torch.cuda.get_device_properties(i) - total = props.total_memory / (1024**3) - fallback.append( - GPUInfo( - index=i, - name=props.name, - vram_total_gb=total, - vram_free_gb=total, # can't query free without setting device - ) - ) - return fallback - - return [] - - def clear_device_cache(device: torch.device | str) -> None: """Clear GPU memory cache if applicable (no-op for CPU).""" device_type = device.type if isinstance(device, torch.device) else device From f71edffc8a2253feab25b644b5ff81a9fd071414 Mon Sep 17 00:00:00 2001 From: James Nye Date: Thu, 26 Mar 2026 13:37:03 -0700 Subject: [PATCH 4/7] feat: add rocm as a proper uv extra instead of manual pip install Use uv's explicit index + extra system to route torch/torchvision to the ROCm wheel index, matching the existing CUDA pattern. Adds pytorch-triton-rocm as an explicit ROCm dependency since the ROCm index is marked explicit=true. Linux native setup is now just `uv sync --extra rocm`. WSL2 and Windows native still require manual install (different wheel sources). --- README.md | 4 +--- pyproject.toml | 27 +++++++++++++++++++++------ 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index b9ee3800..075fcd98 100644 --- a/README.md +++ b/README.md @@ -235,9 +235,7 @@ CorridorKey supports AMD GPUs via PyTorch's ROCm/HIP backend. The `torch.cuda.*` **Linux native (recommended):** ```bash -# Install AMD's ROCm torch wheels, then sync everything else -pip install torch==2.8.0 torchvision==0.23.0 --index-url https://download.pytorch.org/whl/rocm6.3 -uv sync +uv sync --extra rocm # Verify uv run python -c "import torch; print(torch.cuda.is_available(), torch.cuda.get_device_name(0))" diff --git a/pyproject.toml b/pyproject.toml index fcd3fca5..e282c48b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,10 +50,11 @@ cuda = [ mlx = [ "corridorkey-mlx ; python_version >= '3.11'", ] -# ROCm (AMD GPU): torch must be installed manually from the ROCm wheel index -# before running uv sync — see README.md "AMD ROCm Setup" for instructions. -# A rocm extra is not provided because the ROCm wheel metadata conflicts -# with PyPI's nvidia-* packages, making uv sync --extra rocm unreliable. +rocm = [ + "torch==2.8.0", + "torchvision==0.23.0", + "pytorch-triton-rocm==3.4.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'", +] [dependency-groups] dev = ["pytest", "pytest-cov", "ruff", "hypothesis"] @@ -115,6 +116,7 @@ conflicts = [ [ { extra = "cuda" }, { extra = "mlx" }, + { extra = "rocm" }, ], ] @@ -124,9 +126,22 @@ url = "https://download.pytorch.org/whl/cu128" # CUDA 12.6 doesn't support RTX 5 explicit = true extra = "cuda" +[[tool.uv.index]] +name = "pytorch-rocm" +url = "https://download.pytorch.org/whl/rocm6.3" +explicit = true +extra = "rocm" + [tool.uv.sources] # Use Hiera fix in order to utilize the FlashAttention Kernel timm = { git = "https://github.com/Raiden129/pytorch-image-models-fix", branch = "fix/hiera-flash-attention-global-4d" } -torch = { index = "pytorch", extra = "cuda" } -torchvision = { index = "pytorch", extra = "cuda" } +torch = [ + { index = "pytorch", extra = "cuda" }, + { index = "pytorch-rocm", extra = "rocm" }, +] +torchvision = [ + { index = "pytorch", extra = "cuda" }, + { index = "pytorch-rocm", extra = "rocm" }, +] +pytorch-triton-rocm = { index = "pytorch-rocm", extra = "rocm" } corridorkey-mlx = { git = "https://github.com/nikopueringer/corridorkey-mlx.git", extra = "mlx" } From ebe684da7aa69986021e747b5485268d8064044f Mon Sep 17 00:00:00 2001 From: James Nye Date: Thu, 26 Mar 2026 14:09:01 -0700 Subject: [PATCH 5/7] chore: update uv.lock with rocm extra --- uv.lock | 407 ++++++++++++++++++++++++++++++-------------------------- 1 file changed, 219 insertions(+), 188 deletions(-) diff --git a/uv.lock b/uv.lock index 7e8fe882..278a0aeb 100644 --- a/uv.lock +++ b/uv.lock @@ -2,22 +2,14 @@ version = 1 revision = 3 requires-python = ">=3.10, <3.14" resolution-markers = [ - "python_full_version >= '3.12' and extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx'", - "python_full_version == '3.11.*' and extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx'", - "python_full_version < '3.11' and extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx') or (python_full_version >= '3.12' and sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx')", - "python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx') or (python_full_version == '3.11.*' and sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx')", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx') or (python_full_version < '3.11' and sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx')", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "python_full_version >= '3.12' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "python_full_version == '3.11.*' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "python_full_version < '3.11' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", + "python_full_version < '3.11'", ] conflicts = [[ { package = "corridorkey", extra = "cuda" }, { package = "corridorkey", extra = "mlx" }, + { package = "corridorkey", extra = "rocm" }, ]] [manifest] @@ -29,14 +21,15 @@ version = "1.12.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "huggingface-hub" }, - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, { name = "packaging" }, { name = "psutil" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-cuda'" }, - { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-11-corridorkey-cuda'" }, + { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-corridorkey-mlx' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm')" }, + { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "torch", version = "2.8.0+rocm6.3", source = { registry = "https://download.pytorch.org/whl/rocm6.3" }, marker = "(extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra != 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/4a/8e/ac2a9566747a93f8be36ee08532eb0160558b07630a081a6056a9f89bf1d/accelerate-1.12.0.tar.gz", hash = "sha256:70988c352feb481887077d2ab845125024b2a137a5090d6d7a32b57d03a45df6", size = 398399, upload-time = "2025-11-21T11:27:46.973Z" } wheels = [ @@ -57,9 +50,9 @@ name = "anyio" version = "4.12.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, { name = "idna" }, - { name = "typing-extensions", marker = "python_full_version < '3.13' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "typing-extensions", marker = "python_full_version < '3.13' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/96/f0/5eb65b2bb0d09ac6776f2eb54adee6abe8228ea05b20a5ad0e4945de8aac/anyio-4.12.1.tar.gz", hash = "sha256:41cfcc3a4c85d3f05c932da7c26d0201ac36f72abd4435ba90d0464a3ffed703", size = 228685, upload-time = "2026-01-06T11:45:21.246Z" } wheels = [ @@ -196,7 +189,7 @@ name = "click" version = "8.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/3d/fa/656b739db8587d7b5dfa22e22ed02566950fbfbcdc20311993483657a5c0/click-8.3.1.tar.gz", hash = "sha256:12ff4785d337a1bb490bb7e9c2b1ee5da3112e94a8622f26a6c77f5d2fc6842a", size = 295065, upload-time = "2025-11-15T20:45:42.706Z" } wheels = [ @@ -217,13 +210,10 @@ name = "contourpy" version = "1.3.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.11' and extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx') or (python_full_version < '3.11' and sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx')", - "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "python_full_version < '3.11' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", + "python_full_version < '3.11'", ] dependencies = [ - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/66/54/eb9bfc647b19f2009dd5c7f5ec51c4e6ca831725f1aea7a993034f483147/contourpy-1.3.2.tar.gz", hash = "sha256:b6945942715a034c671b7fc54f9588126b0b8bf23db2696e3ca8328f3ff0ab54", size = 13466130, upload-time = "2025-04-15T17:47:53.79Z" } wheels = [ @@ -290,17 +280,11 @@ name = "contourpy" version = "1.3.3" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.12' and extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx'", - "python_full_version == '3.11.*' and extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx') or (python_full_version >= '3.12' and sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx')", - "python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx') or (python_full_version == '3.11.*' and sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx')", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "python_full_version >= '3.12' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "python_full_version == '3.11.*' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", ] dependencies = [ - { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/58/01/1253e6698a07380cd31a736d248a3f2a50a7c88779a1813da27503cadc2a/contourpy-1.3.3.tar.gz", hash = "sha256:083e12155b210502d0bca491432bb04d56dc3432f95a979b429f2848c3dbe880", size = 13466174, upload-time = "2025-07-26T12:03:12.549Z" } wheels = [ @@ -369,8 +353,8 @@ dependencies = [ { name = "imageio" }, { name = "kornia" }, { name = "matplotlib" }, - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, { name = "opencv-python" }, { name = "peft" }, { name = "pillow" }, @@ -378,13 +362,15 @@ dependencies = [ { name = "rich" }, { name = "setuptools" }, { name = "timm" }, - { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-cuda'" }, - { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-11-corridorkey-cuda'" }, - { name = "torchvision", version = "0.23.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-cuda'" }, - { name = "torchvision", version = "0.23.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-11-corridorkey-cuda'" }, + { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-corridorkey-mlx' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm')" }, + { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "torch", version = "2.8.0+rocm6.3", source = { registry = "https://download.pytorch.org/whl/rocm6.3" }, marker = "(extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra != 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "torchvision", version = "0.23.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-corridorkey-mlx' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm')" }, + { name = "torchvision", version = "0.23.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "torchvision", version = "0.23.0+rocm6.3", source = { registry = "https://download.pytorch.org/whl/rocm6.3" }, marker = "(extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra != 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, { name = "tqdm" }, { name = "transformers" }, - { name = "triton-windows", marker = "sys_platform == 'win32' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "triton-windows", marker = "sys_platform == 'win32' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, { name = "typer" }, ] @@ -396,6 +382,11 @@ cuda = [ mlx = [ { name = "corridorkey-mlx", marker = "python_full_version >= '3.11'" }, ] +rocm = [ + { name = "pytorch-triton-rocm", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "torch", version = "2.8.0+rocm6.3", source = { registry = "https://download.pytorch.org/whl/rocm6.3" } }, + { name = "torchvision", version = "0.23.0+rocm6.3", source = { registry = "https://download.pytorch.org/whl/rocm6.3" } }, +] [package.dev-dependencies] dev = [ @@ -425,19 +416,22 @@ requires-dist = [ { name = "peft" }, { name = "pillow" }, { name = "pims" }, + { name = "pytorch-triton-rocm", marker = "platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'rocm'", specifier = "==3.4.0", index = "https://download.pytorch.org/whl/rocm6.3", conflict = { package = "corridorkey", extra = "rocm" } }, { name = "rich", specifier = ">=13" }, { name = "setuptools" }, { name = "timm", git = "https://github.com/Raiden129/pytorch-image-models-fix?branch=fix%2Fhiera-flash-attention-global-4d" }, { name = "torch", specifier = "==2.8.0" }, { name = "torch", marker = "extra == 'cuda'", specifier = "==2.8.0", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "corridorkey", extra = "cuda" } }, + { name = "torch", marker = "extra == 'rocm'", specifier = "==2.8.0", index = "https://download.pytorch.org/whl/rocm6.3", conflict = { package = "corridorkey", extra = "rocm" } }, { name = "torchvision", specifier = "==0.23.0" }, { name = "torchvision", marker = "extra == 'cuda'", specifier = "==0.23.0", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "corridorkey", extra = "cuda" } }, + { name = "torchvision", marker = "extra == 'rocm'", specifier = "==0.23.0", index = "https://download.pytorch.org/whl/rocm6.3", conflict = { package = "corridorkey", extra = "rocm" } }, { name = "tqdm" }, { name = "transformers" }, { name = "triton-windows", marker = "sys_platform == 'win32'", specifier = "==3.4.0.post21" }, { name = "typer", specifier = ">=0.12" }, ] -provides-extras = ["cuda", "mlx"] +provides-extras = ["cuda", "mlx", "rocm"] [package.metadata.requires-dev] dev = [ @@ -547,7 +541,7 @@ wheels = [ [package.optional-dependencies] toml = [ - { name = "tomli", marker = "python_full_version <= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "tomli", marker = "python_full_version <= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, ] [[package]] @@ -577,8 +571,8 @@ dependencies = [ { name = "httpx" }, { name = "huggingface-hub" }, { name = "importlib-metadata" }, - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, { name = "pillow" }, { name = "regex" }, { name = "requests" }, @@ -612,7 +606,7 @@ name = "exceptiongroup" version = "1.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" } wheels = [ @@ -746,7 +740,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, { name = "fsspec" }, - { name = "hf-xet", marker = "platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "hf-xet", marker = "platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, { name = "httpx" }, { name = "packaging" }, { name = "pyyaml" }, @@ -764,7 +758,7 @@ name = "hypothesis" version = "6.151.9" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, { name = "sortedcontainers" }, ] sdist = { url = "https://files.pythonhosted.org/packages/19/e1/ef365ff480903b929d28e057f57b76cae51a30375943e33374ec9a165d9c/hypothesis-6.151.9.tar.gz", hash = "sha256:2f284428dda6c3c48c580de0e18470ff9c7f5ef628a647ee8002f38c3f9097ca", size = 463534, upload-time = "2026-02-16T22:59:23.09Z" } @@ -786,8 +780,8 @@ name = "imageio" version = "2.37.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, { name = "pillow" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a3/6f/606be632e37bf8d05b253e8626c2291d74c691ddc7bcdf7d6aaf33b32f6a/imageio-2.37.2.tar.gz", hash = "sha256:0212ef2727ac9caa5ca4b2c75ae89454312f440a756fcfc8ef1993e718f50f8a", size = 389600, upload-time = "2025-11-04T14:29:39.898Z" } @@ -917,8 +911,9 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "kornia-rs" }, { name = "packaging" }, - { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-cuda'" }, - { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-11-corridorkey-cuda'" }, + { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-corridorkey-mlx' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm')" }, + { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "torch", version = "2.8.0+rocm6.3", source = { registry = "https://download.pytorch.org/whl/rocm6.3" }, marker = "(extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra != 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/c6/e6/45e757d4924176e4d4e111e10effaab7db382313243e0188a06805010073/kornia-0.8.2.tar.gz", hash = "sha256:5411b2ce0dd909d1608016308cd68faeef90f88c47f47e8ecd40553fd4d8b937", size = 667151, upload-time = "2025-11-08T12:10:03.042Z" } wheels = [ @@ -1047,13 +1042,13 @@ name = "matplotlib" version = "3.10.8" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "contourpy", version = "1.3.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "contourpy", version = "1.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "contourpy", version = "1.3.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "contourpy", version = "1.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, { name = "cycler" }, { name = "fonttools" }, { name = "kiwisolver" }, - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, { name = "packaging" }, { name = "pillow" }, { name = "pyparsing" }, @@ -1166,10 +1161,7 @@ name = "networkx" version = "3.4.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.11' and extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx') or (python_full_version < '3.11' and sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx')", - "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "python_full_version < '3.11' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", + "python_full_version < '3.11'", ] sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" } wheels = [ @@ -1181,14 +1173,8 @@ name = "networkx" version = "3.6.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.12' and extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx'", - "python_full_version == '3.11.*' and extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx') or (python_full_version >= '3.12' and sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx')", - "python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx') or (python_full_version == '3.11.*' and sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx')", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "python_full_version >= '3.12' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "python_full_version == '3.11.*' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", ] sdist = { url = "https://files.pythonhosted.org/packages/6a/51/63fe664f3908c97be9d2e4f1158eb633317598cfa6e1fc14af5383f17512/networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509", size = 2517025, upload-time = "2025-12-08T17:02:39.908Z" } wheels = [ @@ -1200,10 +1186,7 @@ name = "numpy" version = "2.2.6" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.11' and extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx') or (python_full_version < '3.11' and sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx')", - "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "python_full_version < '3.11' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", + "python_full_version < '3.11'", ] sdist = { url = "https://files.pythonhosted.org/packages/76/21/7d2a95e4bba9dc13d043ee156a356c0a8f0c6309dff6b21b4d71a073b8a8/numpy-2.2.6.tar.gz", hash = "sha256:e29554e2bef54a90aa5cc07da6ce955accb83f21ab5de01a62c8478897b264fd", size = 20276440, upload-time = "2025-05-17T22:38:04.611Z" } wheels = [ @@ -1268,14 +1251,8 @@ name = "numpy" version = "2.4.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.12' and extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx'", - "python_full_version == '3.11.*' and extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx') or (python_full_version >= '3.12' and sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx')", - "python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx') or (python_full_version == '3.11.*' and sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx')", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "python_full_version >= '3.12' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "python_full_version == '3.11.*' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", ] sdist = { url = "https://files.pythonhosted.org/packages/57/fd/0005efbd0af48e55eb3c7208af93f2862d4b1a56cd78e84309a2d959208d/numpy-2.4.2.tar.gz", hash = "sha256:659a6107e31a83c4e33f763942275fd278b21d095094044eb35569e86a21ddae", size = 20723651, upload-time = "2026-01-31T23:13:10.135Z" } wheels = [ @@ -1376,7 +1353,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' or sys_platform != 'linux' or extra != 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "nvidia-cublas-cu12", marker = "extra == 'extra-11-corridorkey-cuda' or extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-rocm'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" }, @@ -1389,7 +1366,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' or sys_platform != 'linux' or extra != 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-11-corridorkey-cuda' or extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-rocm'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" }, @@ -1421,9 +1398,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' or sys_platform != 'linux' or extra != 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine != 'aarch64' or sys_platform != 'linux' or extra != 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' or sys_platform != 'linux' or extra != 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "nvidia-cublas-cu12", marker = "extra == 'extra-11-corridorkey-cuda' or extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-rocm'" }, + { name = "nvidia-cusparse-cu12", marker = "extra == 'extra-11-corridorkey-cuda' or extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-rocm'" }, + { name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-11-corridorkey-cuda' or extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-rocm'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" }, @@ -1436,7 +1413,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' or sys_platform != 'linux' or extra != 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-11-corridorkey-cuda' or extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-rocm'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" }, @@ -1488,8 +1465,8 @@ name = "opencv-python" version = "4.13.0.92" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/fc/6f/5a28fef4c4a382be06afe3938c64cc168223016fa520c5abaf37e8862aa5/opencv_python-4.13.0.92-cp37-abi3-macosx_13_0_arm64.whl", hash = "sha256:caf60c071ec391ba51ed00a4a920f996d0b64e3e46068aac1f646b5de0326a19", size = 46247052, upload-time = "2026-02-05T07:01:25.046Z" }, @@ -1518,14 +1495,15 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "accelerate" }, { name = "huggingface-hub" }, - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, { name = "packaging" }, { name = "psutil" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-cuda'" }, - { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-11-corridorkey-cuda'" }, + { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-corridorkey-mlx' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm')" }, + { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "torch", version = "2.8.0+rocm6.3", source = { registry = "https://download.pytorch.org/whl/rocm6.3" }, marker = "(extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra != 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, { name = "tqdm" }, { name = "transformers" }, ] @@ -1613,12 +1591,12 @@ version = "0.7" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "imageio" }, - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, { name = "packaging" }, { name = "slicerator" }, - { name = "tifffile", version = "2025.5.10", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "tifffile", version = "2026.2.24", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "tifffile", version = "2025.5.10", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "tifffile", version = "2026.2.24", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b8/02/5bf3639f5b77e9b183011c08541c5039ba3d04f5316c70312b48a8e003a9/pims-0.7.tar.gz", hash = "sha256:55907a4c301256086d2aa4e34a5361b9109f24e375c2071e1117b9491e82946b", size = 87779, upload-time = "2024-06-10T19:20:42.842Z" } @@ -1698,13 +1676,13 @@ name = "pytest" version = "9.0.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, { name = "iniconfig" }, { name = "packaging" }, { name = "pluggy" }, { name = "pygments" }, - { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" } wheels = [ @@ -1737,6 +1715,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, ] +[[package]] +name = "pytorch-triton-rocm" +version = "3.4.0" +source = { registry = "https://download.pytorch.org/whl/rocm6.3" } +dependencies = [ + { name = "setuptools", marker = "(extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra != 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, +] +wheels = [ + { url = "https://download.pytorch.org/whl/pytorch_triton_rocm-3.4.0-cp310-cp310-linux_x86_64.whl", hash = "sha256:1ee0a5cf569175e63b43bc334dcaaf6f9b0d88eb455a452869c2bab14e1f7eb4" }, + { url = "https://download.pytorch.org/whl/pytorch_triton_rocm-3.4.0-cp311-cp311-linux_x86_64.whl", hash = "sha256:b0362725d8e16d185251e3dcd48455ebf9cdaad2c26052bb47ef08a1d687ed20" }, + { url = "https://download.pytorch.org/whl/pytorch_triton_rocm-3.4.0-cp312-cp312-linux_x86_64.whl", hash = "sha256:7afe951b9fc38f1a5b3a7b98bebbaa092bf51e6192b699b4fade9b1ad6fc9c2c" }, + { url = "https://download.pytorch.org/whl/pytorch_triton_rocm-3.4.0-cp313-cp313-linux_x86_64.whl", hash = "sha256:1e7ccba3501fcd38e8cd8415f97a654043370e1fdc5a936bb75abe1bebeb94c9" }, + { url = "https://download.pytorch.org/whl/pytorch_triton_rocm-3.4.0-cp313-cp313t-linux_x86_64.whl", hash = "sha256:c262cd42e38b6955391338cca1c3a779cceb8c51e4b45200d87305c870ef99d7" }, +] + [[package]] name = "pyyaml" version = "6.0.3" @@ -2013,13 +2006,10 @@ name = "tifffile" version = "2025.5.10" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.11' and extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx') or (python_full_version < '3.11' and sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx')", - "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "python_full_version < '3.11' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", + "python_full_version < '3.11'", ] dependencies = [ - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/44/d0/18fed0fc0916578a4463f775b0fbd9c5fed2392152d039df2fb533bfdd5d/tifffile-2025.5.10.tar.gz", hash = "sha256:018335d34283aa3fd8c263bae5c3c2b661ebc45548fde31504016fcae7bf1103", size = 365290, upload-time = "2025-05-10T19:22:34.386Z" } wheels = [ @@ -2031,17 +2021,11 @@ name = "tifffile" version = "2026.2.24" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.12' and extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx'", - "python_full_version == '3.11.*' and extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx') or (python_full_version >= '3.12' and sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx')", - "python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx') or (python_full_version == '3.11.*' and sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx')", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "python_full_version >= '3.12' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", - "python_full_version == '3.11.*' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-mlx'", + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", ] dependencies = [ - { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/6e/1c/19fc653e2b05ec0defae511b03b330ca60c95f2c47fcaaf21c52c6e84aa8/tifffile-2026.2.24.tar.gz", hash = "sha256:d73cfa6d7a8f5775a1e3c9f3bfca77c992946639fb41a5bbe888878cb6964dc6", size = 387373, upload-time = "2026-02-24T23:59:11.706Z" } wheels = [ @@ -2056,10 +2040,12 @@ dependencies = [ { name = "huggingface-hub" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-cuda'" }, - { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-11-corridorkey-cuda'" }, - { name = "torchvision", version = "0.23.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-cuda'" }, - { name = "torchvision", version = "0.23.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-11-corridorkey-cuda'" }, + { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-corridorkey-mlx' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm')" }, + { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "torch", version = "2.8.0+rocm6.3", source = { registry = "https://download.pytorch.org/whl/rocm6.3" }, marker = "(extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra != 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "torchvision", version = "0.23.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-corridorkey-mlx' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm')" }, + { name = "torchvision", version = "0.23.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "torchvision", version = "0.23.0+rocm6.3", source = { registry = "https://download.pytorch.org/whl/rocm6.3" }, marker = "(extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra != 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, ] [[package]] @@ -2138,29 +2124,29 @@ resolution-markers = [ "python_full_version < '3.11'", ] dependencies = [ - { name = "filelock", marker = "extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-cuda'" }, - { name = "fsspec", marker = "extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-cuda'" }, - { name = "jinja2", marker = "extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-cuda'" }, - { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "networkx", version = "3.6.1", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cublas-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cuda-cupti-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cuda-runtime-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cudnn-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cufft-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cufile-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-curand-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cusolver-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cusparse-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cusparselt-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-nccl-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-nvtx-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "setuptools", marker = "(python_full_version >= '3.12' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "sympy", marker = "extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-cuda'" }, - { name = "triton", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "typing-extensions", marker = "extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-cuda'" }, + { name = "filelock", marker = "extra == 'extra-11-corridorkey-mlx' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm')" }, + { name = "fsspec", marker = "extra == 'extra-11-corridorkey-mlx' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm')" }, + { name = "jinja2", marker = "extra == 'extra-11-corridorkey-mlx' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm')" }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra == 'extra-11-corridorkey-mlx') or (python_full_version < '3.11' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "networkx", version = "3.6.1", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra == 'extra-11-corridorkey-mlx') or (python_full_version >= '3.11' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-mlx') or (platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cuda-cupti-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-mlx') or (platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-mlx') or (platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cuda-runtime-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-mlx') or (platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cudnn-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-mlx') or (platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cufft-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-mlx') or (platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cufile-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-mlx') or (platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-curand-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-mlx') or (platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cusolver-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-mlx') or (platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cusparse-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-mlx') or (platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cusparselt-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-mlx') or (platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-nccl-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-mlx') or (platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-mlx') or (platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-nvtx-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-mlx') or (platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "setuptools", marker = "(python_full_version >= '3.12' and extra == 'extra-11-corridorkey-mlx') or (python_full_version >= '3.12' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "sympy", marker = "extra == 'extra-11-corridorkey-mlx' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm')" }, + { name = "triton", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-mlx') or (platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "typing-extensions", marker = "extra == 'extra-11-corridorkey-mlx' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/63/28/110f7274254f1b8476c561dada127173f994afa2b1ffc044efb773c15650/torch-2.8.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:0be92c08b44009d4131d1ff7a8060d10bafdb7ddcb7359ef8d8c5169007ea905", size = 102052793, upload-time = "2025-08-06T14:53:15.852Z" }, @@ -2190,37 +2176,34 @@ name = "torch" version = "2.8.0+cu128" source = { registry = "https://download.pytorch.org/whl/cu128" } resolution-markers = [ - "(python_full_version >= '3.12' and platform_machine != 'aarch64') or (python_full_version >= '3.12' and sys_platform != 'linux')", - "python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64') or (python_full_version == '3.11.*' and sys_platform != 'linux')", - "(python_full_version < '3.11' and platform_machine != 'aarch64') or (python_full_version < '3.11' and sys_platform != 'linux')", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", + "python_full_version < '3.11'", ] dependencies = [ - { name = "filelock", marker = "extra == 'extra-11-corridorkey-cuda'" }, - { name = "fsspec", marker = "extra == 'extra-11-corridorkey-cuda'" }, - { name = "jinja2", marker = "extra == 'extra-11-corridorkey-cuda'" }, - { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra == 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "networkx", version = "3.6.1", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra == 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cublas-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cuda-cupti-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cuda-runtime-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cudnn-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cufft-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cufile-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-curand-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cusolver-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cusparse-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-cusparselt-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-nccl-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "nvidia-nvtx-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "setuptools", marker = "(python_full_version >= '3.12' and extra == 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "sympy", marker = "extra == 'extra-11-corridorkey-cuda'" }, - { name = "triton", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "typing-extensions", marker = "extra == 'extra-11-corridorkey-cuda'" }, + { name = "filelock", marker = "extra == 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "fsspec", marker = "extra == 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "jinja2", marker = "extra == 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra == 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "networkx", version = "3.6.1", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra == 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cuda-cupti-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cuda-runtime-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cudnn-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cufft-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cufile-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-curand-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cusolver-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cusparse-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-cusparselt-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-nccl-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "nvidia-nvtx-cu12", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "setuptools", marker = "(python_full_version >= '3.12' and extra == 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "sympy", marker = "extra == 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "triton", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-cuda') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "typing-extensions", marker = "extra == 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, ] wheels = [ { url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:0c96999d15cf1f13dd7c913e0b21a9a355538e6cfc10861a17158320292f5954" }, @@ -2235,6 +2218,34 @@ wheels = [ { url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:970b4f4661fa7b44f6a7e6df65de7fc4a6fff2af610dc415c1d695ca5f1f37d2" }, ] +[[package]] +name = "torch" +version = "2.8.0+rocm6.3" +source = { registry = "https://download.pytorch.org/whl/rocm6.3" } +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", + "python_full_version < '3.11'", +] +dependencies = [ + { name = "filelock", marker = "(extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra != 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "fsspec", marker = "(extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra != 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "jinja2", marker = "(extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra != 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "networkx", version = "3.6.1", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "pytorch-triton-rocm", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-11-corridorkey-rocm') or (platform_machine != 'x86_64' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (sys_platform != 'linux' and extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "setuptools", marker = "(python_full_version >= '3.12' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "sympy", marker = "(extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra != 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "typing-extensions", marker = "(extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra != 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, +] +wheels = [ + { url = "https://download.pytorch.org/whl/rocm6.3/torch-2.8.0%2Brocm6.3-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a492665402d52a71dc3d3c025153ead15ef9d993d1f39abd67dbf8b3d0fadd35" }, + { url = "https://download.pytorch.org/whl/rocm6.3/torch-2.8.0%2Brocm6.3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:8312c2defbd8890fd3cf2fc01cfcc53678fe00891d1c8314901025991e21f081" }, + { url = "https://download.pytorch.org/whl/rocm6.3/torch-2.8.0%2Brocm6.3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:de0bb7f98c977ad9c74312fc34d375747efc5c5ddee60a4cd5163fa729211602" }, + { url = "https://download.pytorch.org/whl/rocm6.3/torch-2.8.0%2Brocm6.3-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:2cdaac5545d66f11030914ae03c85d161f24b14053b67f49500a8f52638ef104" }, + { url = "https://download.pytorch.org/whl/rocm6.3/torch-2.8.0%2Brocm6.3-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:656892d4f59d1138aaade39fafa4a317a81a13238288a98a654a75cf80349774" }, +] + [[package]] name = "torchvision" version = "0.23.0" @@ -2245,10 +2256,10 @@ resolution-markers = [ "python_full_version < '3.11'", ] dependencies = [ - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra != 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "pillow", marker = "extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-cuda'" }, - { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-cuda'" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra == 'extra-11-corridorkey-mlx') or (python_full_version < '3.11' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra == 'extra-11-corridorkey-mlx') or (python_full_version >= '3.11' and extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "pillow", marker = "extra == 'extra-11-corridorkey-mlx' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm')" }, + { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-corridorkey-mlx' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra != 'extra-11-corridorkey-rocm')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/4d/49/5ad5c3ff4920be0adee9eb4339b4fb3b023a0fc55b9ed8dbc73df92946b8/torchvision-0.23.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7266871daca00ad46d1c073e55d972179d12a58fa5c9adec9a3db9bbed71284a", size = 1856885, upload-time = "2025-08-06T14:57:55.024Z" }, @@ -2278,18 +2289,15 @@ name = "torchvision" version = "0.23.0+cu128" source = { registry = "https://download.pytorch.org/whl/cu128" } resolution-markers = [ - "(python_full_version >= '3.12' and platform_machine != 'aarch64') or (python_full_version >= '3.12' and sys_platform != 'linux')", - "python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64') or (python_full_version == '3.11.*' and sys_platform != 'linux')", - "(python_full_version < '3.11' and platform_machine != 'aarch64') or (python_full_version < '3.11' and sys_platform != 'linux')", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", + "python_full_version < '3.11'", ] dependencies = [ - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra == 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra == 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "pillow", marker = "extra == 'extra-11-corridorkey-cuda'" }, - { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-11-corridorkey-cuda'" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra == 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra == 'extra-11-corridorkey-cuda') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "pillow", marker = "extra == 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, ] wheels = [ { url = "https://download-r2.pytorch.org/whl/cu128/torchvision-0.23.0%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:460bc8d70f63bdb433a7351decc2c1ae1903f7f378e4a7614fc8e8c97a5c36aa" }, @@ -2304,12 +2312,35 @@ wheels = [ { url = "https://download-r2.pytorch.org/whl/cu128/torchvision-0.23.0%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:91fd897fb6fefaf25ec56897391b448eff73f28a7e2ab7660886ece85c865ec6" }, ] +[[package]] +name = "torchvision" +version = "0.23.0+rocm6.3" +source = { registry = "https://download.pytorch.org/whl/rocm6.3" } +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", + "python_full_version < '3.11'", +] +dependencies = [ + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "pillow", marker = "(extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra != 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "torch", version = "2.8.0+rocm6.3", source = { registry = "https://download.pytorch.org/whl/rocm6.3" }, marker = "(extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra != 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm') or (extra != 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm')" }, +] +wheels = [ + { url = "https://download-r2.pytorch.org/whl/rocm6.3/torchvision-0.23.0%2Brocm6.3-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:c67eae11e74b5e03a2cd5d4eabb3c4fb12a21c261aaac5672765961f4bbc9197" }, + { url = "https://download-r2.pytorch.org/whl/rocm6.3/torchvision-0.23.0%2Brocm6.3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:49b31481565fd410e5f43226a06d7b27a4dfe10e91a0a4c9f48fdffc9bf88720" }, + { url = "https://download-r2.pytorch.org/whl/rocm6.3/torchvision-0.23.0%2Brocm6.3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:53c503ceaea50ac20581bdbb2e088997f07c5da22e24850333533719681faedb" }, + { url = "https://download-r2.pytorch.org/whl/rocm6.3/torchvision-0.23.0%2Brocm6.3-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:5d6b076b857ab8643dd796090499fa3b774822182eb4b27bfe2b0bc7f04289a5" }, + { url = "https://download-r2.pytorch.org/whl/rocm6.3/torchvision-0.23.0%2Brocm6.3-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:5ea4c7e9708e5864dde7302d7bf75aaff9f39d58aa7e833febb50c0fd83dd0a0" }, +] + [[package]] name = "tqdm" version = "4.67.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/09/a9/6ba95a270c6f1fbcd8dac228323f2777d886cb206987444e4bce66338dd4/tqdm-4.67.3.tar.gz", hash = "sha256:7d825f03f89244ef73f1d4ce193cb1774a8179fd96f31d7e1dcde62092b960bb", size = 169598, upload-time = "2026-02-03T17:35:53.048Z" } wheels = [ @@ -2322,8 +2353,8 @@ version = "5.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "huggingface-hub" }, - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, - { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, { name = "packaging" }, { name = "pyyaml" }, { name = "regex" }, @@ -2342,7 +2373,7 @@ name = "triton" version = "3.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "setuptools", marker = "platform_machine != 'aarch64' or sys_platform != 'linux' or extra != 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "setuptools", marker = "extra == 'extra-11-corridorkey-cuda' or extra == 'extra-11-corridorkey-mlx' or extra != 'extra-11-corridorkey-rocm'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/62/ee/0ee5f64a87eeda19bbad9bc54ae5ca5b98186ed00055281fd40fb4beb10e/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ff2785de9bc02f500e085420273bb5cc9c9bb767584a4aa28d6e360cec70128", size = 155430069, upload-time = "2025-07-30T19:58:21.715Z" }, @@ -2357,7 +2388,7 @@ name = "triton-windows" version = "3.4.0.post21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "setuptools", marker = "platform_machine != 'aarch64' or sys_platform != 'linux' or extra != 'extra-11-corridorkey-cuda' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "setuptools" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/9d/77/58e419d467770ccabe83650f22e35cd5cba7e8fe23294fb49621d3c2aa57/triton_windows-3.4.0.post21-cp310-cp310-win_amd64.whl", hash = "sha256:b5d12acb71d8fb4af8baff7588b5d26614bbe383888c22ab27b50948dc3baa82", size = 42672843, upload-time = "2025-10-17T05:09:05.38Z" }, @@ -2422,7 +2453,7 @@ dependencies = [ { name = "pygments" }, { name = "pymdown-extensions" }, { name = "pyyaml" }, - { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx')" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-mlx') or (extra == 'extra-11-corridorkey-cuda' and extra == 'extra-11-corridorkey-rocm') or (extra == 'extra-11-corridorkey-mlx' and extra == 'extra-11-corridorkey-rocm')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/3b/96/9c6cbdd7b351d1023cdbbcf7872d4cb118b0334cfe5821b99e0dd18e3f00/zensical-0.0.24.tar.gz", hash = "sha256:b5d99e225329bf4f98c8022bdf0a0ee9588c2fada7b4df1b7b896fcc62b37ec3", size = 3840688, upload-time = "2026-02-26T09:43:44.557Z" } wheels = [ From dfdf14d53a382a0b06632cdd07d35ac315ad0f0b Mon Sep 17 00:00:00 2001 From: James Nye Date: Thu, 26 Mar 2026 14:17:17 -0700 Subject: [PATCH 6/7] fix: make setup_rocm_env() run before torch import, fix HIP_PATH docstring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move torch import in device_utils.py from module level into the functions that need it, so importing setup_rocm_env no longer triggers a torch import. Restructure corridorkey_cli.py to call setup_rocm_env() before any torch-importing modules. Remove redundant call from inference_engine.py. Fix is_rocm_system() docstring — remove false claim about checking C:\hip path; it only checks the HIP_PATH env var. Update tests for list-style torch sources and lazy torch import. --- CorridorKeyModule/inference_engine.py | 6 ------ corridorkey_cli.py | 14 ++++++++------ device_utils.py | 25 +++++++++++++++---------- tests/test_device_utils.py | 2 +- tests/test_pyproject_structure.py | 22 +++++++++++++++++----- 5 files changed, 41 insertions(+), 28 deletions(-) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py index ad079ebd..476ebffa 100755 --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -13,15 +13,9 @@ import torchvision.transforms.v2 as T import torchvision.transforms.v2.functional as TF -from device_utils import setup_rocm_env - from .core import color_utils as cu from .core.model_transformer import GreenFormer -# ROCm env vars are read at operation time (not import time), so this is -# fine after importing torch. Also sets up pytorch-rocm-gtt if installed. -setup_rocm_env() - # Persist torch.compile autotune cache across runs (default is /tmp which # gets wiped on reboot — saves 10-20 min re-autotuning on ROCm, ~30s on CUDA) _inductor_cache = os.path.join(os.path.expanduser("~"), ".cache", "corridorkey", "inductor") diff --git a/corridorkey_cli.py b/corridorkey_cli.py index b3e991a4..0d015e7c 100644 --- a/corridorkey_cli.py +++ b/corridorkey_cli.py @@ -37,7 +37,12 @@ from rich.prompt import Confirm, IntPrompt, Prompt from rich.table import Table -from clip_manager import ( +# Set ROCm env vars before any module imports torch. +from device_utils import setup_rocm_env + +setup_rocm_env() + +from clip_manager import ( # noqa: E402 LINUX_MOUNT_ROOT, ClipEntry, InferenceSettings, @@ -51,11 +56,8 @@ run_videomama, scan_clips, ) -from CorridorKeyModule.backend import resolve_backend -from device_utils import resolve_device, setup_rocm_env - -# ROCm env vars are read at operation time, so this is fine after imports. -setup_rocm_env() +from CorridorKeyModule.backend import resolve_backend # noqa: E402 +from device_utils import resolve_device # noqa: E402 logger = logging.getLogger(__name__) console = Console() diff --git a/device_utils.py b/device_utils.py index 6778c230..58125b9d 100644 --- a/device_utils.py +++ b/device_utils.py @@ -10,10 +10,11 @@ def is_rocm_system() -> bool: - """Detect if the system has AMD ROCm available. + """Detect if the system has AMD ROCm available (without importing torch). - Checks: /opt/rocm (Linux), HIP_PATH (Windows, default C:\\hip), - HIP_VISIBLE_DEVICES (any platform), CORRIDORKEY_ROCM=1 (explicit opt-in). + Checks: /opt/rocm (Linux), HIP_PATH env var (Windows), HIP_VISIBLE_DEVICES + (any platform), CORRIDORKEY_ROCM=1 (explicit opt-in for cases where + auto-detection fails, e.g. pip-installed ROCm on Windows). """ return ( os.path.exists("/opt/rocm") @@ -26,9 +27,10 @@ def is_rocm_system() -> bool: def setup_rocm_env() -> None: """Set ROCm environment variables and apply optional patches. - These env vars are read by PyTorch/MIOpen at operation time (not import - time), so calling this after ``import torch`` is fine. Safe to call on - non-ROCm systems (no-op). + Must be called before importing torch so that env vars are visible to + PyTorch's initialization. This module intentionally avoids importing + torch at module level to make that possible. Safe to call on non-ROCm + systems (no-op). """ if not is_rocm_system(): return @@ -48,11 +50,10 @@ def setup_rocm_env() -> None: logger.warning("pytorch-rocm-gtt is installed but patch() failed", exc_info=True) -import torch # noqa: E402 - - def detect_best_device() -> str: """Auto-detect best available device: CUDA > MPS > CPU.""" + import torch + if torch.cuda.is_available(): device = "cuda" elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): @@ -76,6 +77,8 @@ def resolve_device(requested: str | None = None) -> str: Raises: RuntimeError: If the requested backend is unavailable. """ + import torch + # CLI arg takes priority, then env var, then auto device = requested if device is None or device == "auto": @@ -107,8 +110,10 @@ def resolve_device(requested: str | None = None) -> str: return device -def clear_device_cache(device: torch.device | str) -> None: +def clear_device_cache(device) -> None: """Clear GPU memory cache if applicable (no-op for CPU).""" + import torch + device_type = device.type if isinstance(device, torch.device) else device if device_type == "cuda": torch.cuda.empty_cache() diff --git a/tests/test_device_utils.py b/tests/test_device_utils.py index f323baca..4a80da45 100644 --- a/tests/test_device_utils.py +++ b/tests/test_device_utils.py @@ -119,7 +119,7 @@ def test_mps_no_backend_raises(self, monkeypatch): _patch_gpu(monkeypatch, cuda=False, mps=False) # Replace torch.backends with an object that lacks "mps" entirely fake_backends = type("Backends", (), {})() - monkeypatch.setattr("device_utils.torch.backends", fake_backends) + monkeypatch.setattr(torch, "backends", fake_backends) with pytest.raises(RuntimeError, match="no MPS support"): resolve_device("mps") diff --git a/tests/test_pyproject_structure.py b/tests/test_pyproject_structure.py index 816c2d97..30d332e5 100644 --- a/tests/test_pyproject_structure.py +++ b/tests/test_pyproject_structure.py @@ -86,14 +86,26 @@ class TestUvSources: def test_torch_source_has_cuda_extra(self, pyproject: dict) -> None: sources = pyproject["tool"]["uv"]["sources"] torch_src = sources["torch"] - assert torch_src.get("extra") == "cuda" - assert "marker" not in torch_src, "torch source should not have platform markers" + cuda_entry = next(s for s in torch_src if s.get("extra") == "cuda") + assert cuda_entry["index"] == "pytorch" def test_torchvision_source_has_cuda_extra(self, pyproject: dict) -> None: sources = pyproject["tool"]["uv"]["sources"] tv_src = sources["torchvision"] - assert tv_src.get("extra") == "cuda" - assert "marker" not in tv_src, "torchvision source should not have platform markers" + cuda_entry = next(s for s in tv_src if s.get("extra") == "cuda") + assert cuda_entry["index"] == "pytorch" + + def test_torch_source_has_rocm_extra(self, pyproject: dict) -> None: + sources = pyproject["tool"]["uv"]["sources"] + torch_src = sources["torch"] + rocm_entry = next(s for s in torch_src if s.get("extra") == "rocm") + assert rocm_entry["index"] == "pytorch-rocm" + + def test_torchvision_source_has_rocm_extra(self, pyproject: dict) -> None: + sources = pyproject["tool"]["uv"]["sources"] + tv_src = sources["torchvision"] + rocm_entry = next(s for s in tv_src if s.get("extra") == "rocm") + assert rocm_entry["index"] == "pytorch-rocm" # --------------------------------------------------------------------------- @@ -110,7 +122,7 @@ def test_cuda_mlx_conflict_declared(self, pyproject: dict) -> None: extras_in_groups = [ {entry["extra"] for entry in group} for group in conflicts if all("extra" in entry for entry in group) ] - assert {"cuda", "mlx"} in extras_in_groups, "Expected a conflict group containing both 'cuda' and 'mlx' extras" + assert {"cuda", "mlx", "rocm"} in extras_in_groups, "Expected a conflict group containing 'cuda', 'mlx', and 'rocm' extras" # --------------------------------------------------------------------------- From fe61b01893ee69f8621776600db5444505570834 Mon Sep 17 00:00:00 2001 From: James Nye Date: Thu, 26 Mar 2026 14:19:47 -0700 Subject: [PATCH 7/7] style: format test_pyproject_structure.py --- tests/test_pyproject_structure.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_pyproject_structure.py b/tests/test_pyproject_structure.py index 30d332e5..c11e4b2b 100644 --- a/tests/test_pyproject_structure.py +++ b/tests/test_pyproject_structure.py @@ -122,7 +122,9 @@ def test_cuda_mlx_conflict_declared(self, pyproject: dict) -> None: extras_in_groups = [ {entry["extra"] for entry in group} for group in conflicts if all("extra" in entry for entry in group) ] - assert {"cuda", "mlx", "rocm"} in extras_in_groups, "Expected a conflict group containing 'cuda', 'mlx', and 'rocm' extras" + assert {"cuda", "mlx", "rocm"} in extras_in_groups, ( + "Expected a conflict group containing 'cuda', 'mlx', and 'rocm' extras" + ) # ---------------------------------------------------------------------------