feat: AMD ROCm GPU support — device enumeration, compile mode, SDPA fix #206
Conversation
(force-pushed 2c7421d to 7fcd2ae)
Hey, thanks for the work on this. This is a solid PR overall, and the runtime changes look genuinely useful. I did want to flag two issues that I think should be fixed before merge.
Right now the rocm extra is just:

rocm = [
    "torch==2.8.0",
    "torchvision==0.23.0",
]

But unlike the cuda extra, nothing routes these packages to the ROCm wheel index, so `uv sync --extra rocm` resolves torch and torchvision from the default index and you end up with non-ROCm wheels. I know this is partially mentioned in the README and in the comment in pyproject.toml, but the package metadata still advertises an extra that looks installable as-is. Right now it is not, and that is going to create confusion for users. I think this should go one of two ways:
1. Wire the extra up so it actually pulls the ROCm wheels.
2. If the extra is not actually functional, it is better not to advertise it in the metadata.
The pre-import ROCm detection is:

_is_rocm_system = os.environ.get("HIP_VISIBLE_DEVICES") is not None or os.path.exists("/opt/rocm")

On Linux this is reasonable, because /opt/rocm is the standard ROCm install prefix. On native Windows, neither of these conditions is likely to be true by default.
Per the ROCm docs, the Windows HIP SDK installs under a Windows-style path and exposes it via the HIP_PATH environment variable, not /opt/rocm. Also, HIP_VISIBLE_DEVICES is not something a typical user will have set. That means on a normal native Windows ROCm install, this block probably does not run:

os.environ.setdefault("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "1")
os.environ.setdefault("MIOPEN_FIND_MODE", "2")
os.environ.setdefault("MIOPEN_LOG_LEVEL", "0")

This does not necessarily mean Windows ROCm is completely broken. The later post-import check:

is_rocm = hasattr(torch.version, "hip") and torch.version.hip

is much better, and things like skipping torch.compile on Windows ROCm will still work. The problem is that the pre-import ROCm env vars that the PR and README describe as important probably will not get applied on a normal Windows install. In particular, TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 (which the PR says is needed for flash attention on RDNA3) and MIOPEN_FIND_MODE=2 (fast kernel selection) would silently stay unset.
So I think the Windows claim should either be narrowed, or the detection logic should be made platform-aware. For example, check Windows ROCm paths as well, or require an explicit opt-in env var for experimental Windows ROCm support. The runtime work, compile handling, cache persistence, and documentation all look useful. I just think these two issues should be cleaned up so the install surface and platform behavior match the claims.
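A platform-aware version of the check being asked for could look something like this (a minimal sketch, not code from the PR; the HIP_PATH fallback on Windows is an assumption based on how the HIP SDK exposes its install location):

```python
import os
import platform

def is_rocm_system() -> bool:
    """Heuristic pre-import ROCm detection (hypothetical sketch)."""
    # An explicitly set HIP_VISIBLE_DEVICES is a strong signal on any OS.
    if os.environ.get("HIP_VISIBLE_DEVICES") is not None:
        return True
    if platform.system() == "Windows":
        # Native Windows ROCm installs typically set HIP_PATH rather
        # than creating /opt/rocm (assumption; verify on your install).
        return os.environ.get("HIP_PATH") is not None
    # Standard ROCm install prefix on Linux.
    return os.path.exists("/opt/rocm")
```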
Add support for AMD GPUs (RDNA3/RDNA4) via ROCm's HIP translation layer. Most PyTorch code works unchanged since HIP maps torch.cuda.* APIs, but three areas needed explicit AMD handling:

device_utils.py:
- Add GPUInfo dataclass and enumerate_gpus() function
- Tries nvidia-smi, then amd-smi (ROCm 6.0+), then rocm-smi (legacy), then torch.cuda fallback — works on NVIDIA, AMD, or either

inference_engine.py:
- torch.compile uses "max-autotune-no-cudagraphs" on ROCm to avoid a known HIP graph segfault on large graphs (pytorch/pytorch#155720)
- Auto-sets TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 on ROCm so SDPA dispatches to AOTriton flash attention on RDNA3 (without this it silently falls back to O(n^2) math backend)
- Skips torch.compile entirely on Windows ROCm where Triton kernel compilation hangs indefinitely (eager fallback works fine)

pyproject.toml:
- Add "rocm" optional extra with install instructions

README.md:
- Add AMD ROCm to Hardware Requirements
- Add ROCm install instructions for Linux and Windows
- Add dedicated AMD ROCm Setup section with supported GPUs, automatic behavior, and known limitations

Tested on RX 7800 XT (gfx1101) Windows with AMD ROCm 7.2:
- torch.cuda.is_available() = True
- SDPA with float16 = OK
- Raw inference = OK
- torch.compile = hangs on Windows (skipped), works on Linux
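The smi-tool fallback chain described in the commit message could be sketched like this (hypothetical helper name; the argument lists are illustrative, and the real enumerate_gpus() also parses each tool's output into GPUInfo records, which is omitted here):

```python
import shutil
import subprocess

def probe_gpu_tool():
    """Try each vendor CLI in order and return the first that responds.

    Tool arguments are illustrative; check each tool's --help output.
    """
    candidates = (
        ("nvidia-smi", ["--query-gpu=name", "--format=csv,noheader"]),
        ("amd-smi", ["list", "--json"]),      # ROCm 6.0+
        ("rocm-smi", ["--showproductname"]),  # legacy ROCm
    )
    for tool, args in candidates:
        if shutil.which(tool) is None:
            continue  # tool not on PATH, try the next one
        try:
            result = subprocess.run(
                [tool, *args], capture_output=True, text=True,
                timeout=10, check=True,
            )
            return tool, result.stdout
        except (subprocess.SubprocessError, OSError):
            continue  # tool present but failed; fall through
    return None, ""  # caller falls back to torch.cuda enumeration
```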
(force-pushed 7fcd2ae to 5acb900)
- Extract is_rocm_system() and setup_rocm_env() into device_utils.py as shared functions. Both corridorkey_cli.py and inference_engine.py now call setup_rocm_env() instead of duplicating the detection + env var logic.
- device_utils.py restructured so setup_rocm_env() is defined before `import torch`, allowing callers to set env vars before torch init.
- Broaden pytorch_rocm_gtt exception handler from ImportError to Exception (catches patch() runtime errors too, not just missing package).
- Add per-entry error handling in _enumerate_amd() — malformed amd-smi JSON entries are skipped instead of crashing the whole enumeration. Add json.JSONDecodeError to the outer exception handler.
- Use top-level `import json` instead of local `import json as _json`.
- Change MIOPEN_LOG_LEVEL from 0 (silent) to 4 (suppress info/debug, keep warnings and errors visible).
- Compile message is now conditional: "10-20 minutes on first run (ROCm)" vs just "Compiling model..." on NVIDIA where it takes ~30 seconds.
- README VRAM claim clarified: ~10GB on NVIDIA, ~18GB on AMD due to HIP allocator overhead. Prevents scaring off 12GB NVIDIA card owners.
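The per-entry error handling for amd-smi output might look like this (a sketch; the JSON field names are assumptions, since the real amd-smi schema is not shown in the PR):

```python
import json

def parse_amd_smi_names(raw: str) -> list:
    """Extract GPU names from amd-smi JSON, skipping malformed entries."""
    names = []
    try:
        entries = json.loads(raw)
    except json.JSONDecodeError:
        return names  # unparseable output: report no GPUs rather than crash
    for entry in entries:
        try:
            # "asic"/"market_name" are hypothetical field names.
            names.append(entry["asic"]["market_name"])
        except (KeyError, TypeError):
            continue  # one malformed entry must not abort enumeration
    return names
```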
Big moves happening here! Let me know when you feel like this is ready to be merged.
…e error handling

1. Remove dead code: GPUInfo, enumerate_gpus(), _enumerate_nvidia(), _enumerate_amd() — 158 lines with zero callers in upstream. These are cloud-platform code that doesn't belong in this PR.
2. Remove unused imports (json, subprocess, dataclass) that were only needed by the dead GPU enumeration code.
3. Fix misleading "deferred so setup_rocm_env() can run first" comment. The env vars are read at operation time by PyTorch/MIOpen C libraries, not at torch import time. The code works correctly either way.
4. Remove all noqa: E402 comments — imports are now in normal order since the pre-import ordering was unnecessary.
5. Better exception handling in setup_rocm_env(): ImportError caught separately (expected, silent) vs other exceptions (logged as warning with traceback for debugging broken pytorch-rocm-gtt installs).
6. Cache is_rocm as self._is_rocm instance attribute instead of recomputing hasattr(torch.version, "hip") in both __init__ and _compile.
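The split exception handling in item 5 could be sketched as follows (pytorch_rocm_gtt is the package named in the PR; the helper name and everything else here is illustrative):

```python
import logging

logger = logging.getLogger(__name__)

def apply_rocm_gtt_patch() -> bool:
    """Apply the GTT patch if the package is installed (sketch).

    ImportError is the expected case on most systems and stays silent;
    any other failure is logged with a traceback for debugging.
    """
    try:
        import pytorch_rocm_gtt
        pytorch_rocm_gtt.patch()
        return True
    except ImportError:
        return False  # package not installed: normal, no noise
    except Exception:
        logger.warning("pytorch-rocm-gtt patch() failed", exc_info=True)
        return False
```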
Hi, thanks for the follow-up changes. This is in much better shape now, but I see two minor issues worth fixing before merge:
1. setup_rocm_env() does not actually run before torch is imported. device_utils.py still imports torch at module level, so by the time a caller invokes setup_rocm_env(), torch has already been initialized.
2. The is_rocm_system() docstring claims it checks a C:\hip path, but the code only checks the HIP_PATH env var.
MarcelLieb left a comment
That would be my suggestion for how a rocm extra could work (Not tested)
pyproject.toml (outdated)
torch = [
{ index = "pytorch-cu128", extra = "cuda" },
{ index = "pytorch-rocm", extra = "rocm" },
]
torchvision = [
{ index = "pytorch-cu128", extra = "cuda" },
{ index = "pytorch-rocm", extra = "rocm" },
]
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128" # CUDA 12.6 doesn't support RTX 5000 Series
explicit = true
[[tool.uv.index]]
name = "pytorch-rocm"
url = "https://download.pytorch.org/whl/rocm6.3"
explicit = true
rocm = [
"torch==2.8.0",
"torchvision==0.23.0",
]

(force-pushed a7a5478 to 6025c57)
Use uv's explicit index + extra system to route torch/torchvision to the ROCm wheel index, matching the existing CUDA pattern. Adds pytorch-triton-rocm as an explicit ROCm dependency since the ROCm index is marked explicit=true. Linux native setup is now just `uv sync --extra rocm`. WSL2 and Windows native still require manual install (different wheel sources).
(force-pushed 6025c57 to f71edff)
Thanks Marcel! Your changes DID work. I'm addressing the lint failures and will also fix the things Raiden pointed out~
…tring Move torch import in device_utils.py from module level into the functions that need it, so importing setup_rocm_env no longer triggers a torch import. Restructure corridorkey_cli.py to call setup_rocm_env() before any torch-importing modules. Remove redundant call from inference_engine.py. Fix is_rocm_system() docstring — remove false claim about checking C:\hip path; it only checks the HIP_PATH env var. Update tests for list-style torch sources and lazy torch import.
Both fixed, thanks for catching these.

1. setup_rocm_env() now genuinely runs before torch import

# Set ROCm env vars before any module imports torch.
from device_utils import setup_rocm_env
setup_rocm_env()

from clip_manager import ...  # noqa: E402
from CorridorKeyModule.backend import resolve_backend  # noqa: E402

The redundant call in inference_engine.py has been removed.

2. is_rocm_system() docstring fixed

Removed the false claim about checking a C:\hip path; it only checks the HIP_PATH env var.

Also in this push: the tests were updated for the list-style torch sources and the lazy torch import.
MarcelLieb left a comment
Does uv sync --extra rocm not work with WSL?
# 3. Install AMD's WSL torch wheels (Python 3.12)
pip3 install \
    https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp312-cp312-linux_x86_64.whl \
This comment was marked as resolved.
You're right that the WSL2 section in the README uses different wheels: torch 2.9.1 from repo.radeon.com, which are AMD's WSL-specific builds. Those were originally there because the standard ROCm wheels weren't tested on WSL2 at the time.
Now that uv sync --extra rocm works, the WSL2 section could potentially just point to that instead. The separate AMD WSL wheels + manual ROCm driver install may still be needed for some setups though (older drivers, specific GPU quirks), so I'm inclined to keep both paths documented.
AMD ROCm GPU Support (RDNA3 / RDNA4)

Adds support for AMD Radeon GPUs via PyTorch's ROCm/HIP backend. Most inference code works unchanged because HIP intercepts all torch.cuda.* calls transparently — the changes are in GPU enumeration, torch.compile configuration, and environment variable setup.

What changed

device_utils.py
- GPUInfo dataclass and enumerate_gpus() function

corridorkey_cli.py
- Applies pytorch-rocm-gtt if installed

CorridorKeyModule/inference_engine.py
- Sets ROCm env vars before import torch (for direct engine usage):
  - TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 — enables flash attention on RDNA3 (without this, SDPA silently falls back to O(n²) math kernel)
  - MIOPEN_FIND_MODE=2 — fast convolution kernel selection (reduces warmup from 5-8 min to seconds)
  - MIOPEN_LOG_LEVEL=0 — silences harmless workspace warnings
- torch.compile uses mode="default" on ROCm (max-autotune OOM-kills 16GB cards during kernel benchmarking)
- torch.compile skipped entirely on Windows ROCm (Triton compilation hangs indefinitely)
- Applies pytorch-rocm-gtt if installed (GTT system RAM overflow for 16GB cards on Linux)
- TORCHINDUCTOR_CACHE_DIR persisted to ~/.cache/corridorkey/inductor/ (default /tmp gets wiped, causing 10-20 min re-autotune on every reboot)
- CORRIDORKEY_SKIP_COMPILE=1 env var to force eager mode for testing

pyproject.toml
- rocm optional extra with install instructions in comments

README.md
- AMD ROCm in Hardware Requirements, ROCm install instructions for Linux and Windows, and a dedicated AMD ROCm Setup section

Supported GPUs

Tested on: RX 7800 XT (gfx1101), Windows, AMD ROCm 7.2

Platform comparison

Not changed
- Behavior on NVIDIA systems (all ROCm handling is gated behind the torch.version.hip or /opt/rocm checks)

Checklist
- uv run pytest passes
- uv run ruff check passes
- uv run ruff format --check passes
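The inductor-cache persistence described in the PR can be sketched as follows (the cache path is the one named in this PR; the helper name and directory creation are assumptions about how the real code handles a missing path):

```python
import os
from pathlib import Path

def persist_inductor_cache() -> Path:
    """Point TorchInductor's cache at a directory that survives reboots."""
    cache_dir = Path.home() / ".cache" / "corridorkey" / "inductor"
    cache_dir.mkdir(parents=True, exist_ok=True)
    # Respect an existing user-provided TORCHINDUCTOR_CACHE_DIR.
    os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", str(cache_dir))
    return cache_dir
```

Without this, the default cache lands under /tmp, so every reboot pays the 10-20 minute ROCm re-autotune cost the description mentions.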